From ea06e6ca7b5b0f5aad15ac443db04a264a80b87f Mon Sep 17 00:00:00 2001 From: Xinye Date: Wed, 25 Oct 2023 11:48:25 +0800 Subject: [PATCH] support async callback and pause Signed-off-by: Xinye --- Cargo.toml | 6 ++ src/lib.rs | 213 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 216 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 52244e8..6c50988 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,12 +14,18 @@ edition = "2021" exclude = ["/.github/*", "/.travis.yml", "/appveyor.yml"] [dependencies] +futures = { version = "0.3", optional = true } log = { version = "0.4", features = ["std"] } once_cell = "1.9.0" rand = "0.8" +tokio = { version = "1.32", features = [ "sync" ] } + +[dev-dependencies] +tokio = { version = "1.32", features = [ "sync", "rt-multi-thread", "time", "macros" ] } [features] failpoints = [] +async = [ "futures" ] [package.metadata.docs.rs] all-features = true diff --git a/src/lib.rs b/src/lib.rs index f23cc44..65ff011 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -234,6 +234,162 @@ use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, TryLockError}; use std::time::{Duration, Instant}; use std::{env, thread}; +#[cfg(feature = "async")] +mod async_imp { + use super::*; + use futures::future::BoxFuture; + + #[derive(Clone)] + pub(crate) struct AsyncCallback( + Arc BoxFuture<'static, ()> + Send + Sync + 'static>, + ); + + impl Debug for AsyncCallback { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("AsyncCallback()") + } + } + + impl PartialEq for AsyncCallback { + #[allow(clippy::vtable_address_comparisons)] + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } + } + + impl AsyncCallback { + fn new(f: impl Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static) -> AsyncCallback { + AsyncCallback(Arc::new(f)) + } + + async fn run(&self) { + let callback = &self.0; + callback().await; + } + } + + /// `fail_point` but with support for async callback. + #[macro_export] + #[cfg(all(feature = "failpoints", feature = "async"))] + macro_rules! async_fail_point { + ($name:expr) => {{ + $crate::async_eval($name, |_| { + panic!("Return is not supported for the fail point \"{}\"", $name); + }) + .await; + }}; + ($name:expr, $e:expr) => {{ + if let Some(res) = $crate::async_eval($name, $e).await { + return res; + } + }}; + ($name:expr, $cond:expr, $e:expr) => {{ + if $cond { + $crate::async_fail_point!($name, $e); + } + }}; + } + + /// Configures an async callback to be triggered at the specified + /// failpoint. If the failpoint is not implemented using + /// `async_fail_point`, the execution will raise an exception. + pub fn cfg_async_callback(name: S, f: F) -> Result<(), String> + where + S: Into, + F: Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static, + { + let mut registry = REGISTRY.registry.write().unwrap(); + let p = registry + .entry(name.into()) + .or_insert_with(|| Arc::new(FailPoint::new())); + let action = Action::from_async_callback(f); + let actions = vec![action]; + p.set_actions("callback", actions); + Ok(()) + } + + #[doc(hidden)] + pub async fn async_eval) -> R>(name: &str, f: F) -> Option { + let p = { + let registry = REGISTRY.registry.read().unwrap(); + match registry.get(name) { + None => return None, + Some(p) => p.clone(), + } + }; + p.async_eval(name).await.map(f) + } + + impl Action { + #[cfg(feature = "async")] + fn from_async_callback( + f: impl Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static, + ) -> Action { + let task = Task::CallbackAsync(AsyncCallback::new(f)); + Action { + task, + freq: 1.0, + count: None, + } + } + } + + impl FailPoint { + #[cfg_attr(feature = "cargo-clippy", allow(clippy::option_option))] + async fn async_eval(&self, name: &str) -> Option> { + let task = { + let task = self + .actions + .read() + .unwrap() + .iter() + .filter_map(Action::get_task) + .next(); + match task { + Some(Task::Pause) => { + // let n = self.async_pause_notify.clone(); + self.async_pause_notify.notified().await; + return None; + } + Some(t) => t, + None => return None, + } + }; + + match task { + Task::Off => {} + Task::Return(s) => return Some(s), + Task::Sleep(_) => panic!( + "fail does not support async sleep, please use a async closure to sleep." + ), + Task::Panic(msg) => match msg { + Some(ref msg) => panic!("{}", msg), + None => panic!("failpoint {} panic", name), + }, + Task::Print(msg) => match msg { + Some(ref msg) => log::info!("{}", msg), + None => log::info!("failpoint {} executed.", name), + }, + Task::Pause => unreachable!(), + Task::Yield => thread::yield_now(), + Task::Delay(_) => panic!( + "fail does not support async delay, please use a async closure to sleep." + ), + Task::Callback(f) => { + f.run(); + } + Task::CallbackAsync(f) => { + f.run().await; + } + } + None + } + } +} + +#[cfg(feature = "async")] +pub use async_imp::*; + #[derive(Clone)] struct SyncCallback(Arc); @@ -282,6 +438,8 @@ enum Task { Delay(u64), /// Call callback function. Callback(SyncCallback), + #[cfg(feature = "async")] + CallbackAsync(async_imp::AsyncCallback), } #[derive(Debug)] @@ -433,6 +591,8 @@ impl FromStr for Action { struct FailPoint { pause: Mutex, pause_notifier: Condvar, + #[cfg(feature = "async")] + async_pause_notify: tokio::sync::Notify, actions: RwLock>, actions_str: RwLock, } @@ -443,6 +603,8 @@ impl FailPoint { FailPoint { pause: Mutex::new(false), pause_notifier: Condvar::new(), + #[cfg(feature = "async")] + async_pause_notify: tokio::sync::Notify::new(), actions: RwLock::default(), actions_str: RwLock::default(), } @@ -450,6 +612,7 @@ impl FailPoint { fn set_actions(&self, actions_str: &str, actions: Vec) { loop { + self.async_pause_notify.notify_waiters(); // TODO: maybe busy waiting here. match self.actions.try_write() { Err(TryLockError::WouldBlock) => {} @@ -460,9 +623,11 @@ impl FailPoint { } Err(e) => panic!("unexpected poison: {:?}", e), } - let mut guard = self.pause.lock().unwrap(); - *guard = false; - self.pause_notifier.notify_all(); + { + let mut guard = self.pause.lock().unwrap(); + *guard = false; + self.pause_notifier.notify_all(); + } } } @@ -509,6 +674,7 @@ impl FailPoint { Task::Callback(f) => { f.run(); } + Task::CallbackAsync(_) => unreachable!(), } None } @@ -1062,4 +1228,45 @@ mod tests { assert_eq!(rx.recv_timeout(Duration::from_millis(500)).unwrap(), 0); assert_eq!(f1(), 0); } + + #[cfg_attr(not(all(feature = "failpoints", feature = "async")), ignore)] + #[tokio::test] + async fn test_async_failpoint() { + use std::time::Duration; + + let f1 = async { + async_fail_point!("cb"); + }; + let f2 = async { + async_fail_point!("cb"); + }; + + let counter = Arc::new(AtomicUsize::new(0)); + let counter2 = counter.clone(); + cfg_async_callback("cb", move || { + counter2.fetch_add(1, Ordering::SeqCst); + Box::pin(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + }) + }) + .unwrap(); + f1.await; + f2.await; + assert_eq!(2, counter.load(Ordering::SeqCst)); + + cfg("pause", "pause").unwrap(); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let handle = tokio::spawn(async move { + async_fail_point!("pause"); + tx.send(()).await.unwrap(); + }); + tokio::time::timeout(Duration::from_millis(500), rx.recv()) + .await + .unwrap_err(); + remove("pause"); + tokio::time::timeout(Duration::from_millis(500), rx.recv()) + .await + .unwrap(); + handle.await.unwrap(); + } }