Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support async failpoints #73

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@ exclude = ["/.github/*", "/.travis.yml", "/appveyor.yml"]
log = { version = "0.4", features = ["std"] }
once_cell = "1.9.0"
rand = "0.8"
tokio = { version = "1.32", features = ["sync"], optional = true }

[dev-dependencies]
tokio = { version = "1.32", features = ["sync", "rt-multi-thread", "time", "macros"] }

[features]
failpoints = []
async = ["tokio"]

[package.metadata.docs.rs]
all-features = true
241 changes: 241 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ enum Task {
Delay(u64),
/// Call callback function.
Callback(SyncCallback),
#[cfg(feature = "async")]
CallbackAsync(async_imp::AsyncCallback),
}

#[derive(Debug)]
Expand Down Expand Up @@ -433,6 +435,8 @@ impl FromStr for Action {
struct FailPoint {
pause: Mutex<bool>,
pause_notifier: Condvar,
#[cfg(feature = "async")]
async_pause_notify: tokio::sync::Notify,
actions: RwLock<Vec<Action>>,
actions_str: RwLock<String>,
}
Expand All @@ -443,13 +447,17 @@ impl FailPoint {
FailPoint {
pause: Mutex::new(false),
pause_notifier: Condvar::new(),
#[cfg(feature = "async")]
async_pause_notify: tokio::sync::Notify::new(),
actions: RwLock::default(),
actions_str: RwLock::default(),
}
}

fn set_actions(&self, actions_str: &str, actions: Vec<Action>) {
loop {
#[cfg(feature = "async")]
self.async_pause_notify.notify_waiters();
// TODO: maybe busy waiting here.
match self.actions.try_write() {
Err(TryLockError::WouldBlock) => {}
Expand Down Expand Up @@ -509,6 +517,10 @@ impl FailPoint {
Task::Callback(f) => {
f.run();
}
#[cfg(feature = "async")]
Task::CallbackAsync(_) => panic!(
"to use async callback, please enable `async` feature and use `async_fail_point`"
),
}
None
}
Expand Down Expand Up @@ -852,6 +864,179 @@ macro_rules! fail_point {
($name:expr, $cond:expr, $e:expr) => {{}};
}

#[cfg(feature = "async")]
mod async_imp {
use super::*;
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;

#[derive(Clone)]
pub(crate) struct AsyncCallback(
Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static>,
);

impl Debug for AsyncCallback {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("AsyncCallback()")
}
}

impl PartialEq for AsyncCallback {
#[allow(clippy::vtable_address_comparisons)]
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}

impl AsyncCallback {
fn new(f: impl Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static) -> AsyncCallback {
AsyncCallback(Arc::new(f))
}

async fn run(&self) {
let callback = &self.0;
callback().await;
}
}

/// `fail_point` but with support for async callback and pause.
#[macro_export]
#[cfg(feature = "failpoints")]
macro_rules! async_fail_point {
($name:expr) => {{
$crate::async_eval($name, |_| {
panic!("Return is not supported for the fail point \"{}\"", $name);
})
.await;
}};
($name:expr, $e:expr) => {{
if let Some(res) = $crate::async_eval($name, $e).await {
return res;
}
}};
($name:expr, $cond:expr, $e:expr) => {{
if $cond {
$crate::async_fail_point!($name, $e);
}
}};
}

/// Define an async fail point (disabled, see `failpoints` feature).
#[macro_export]
#[cfg(not(feature = "failpoints"))]
macro_rules! async_fail_point {
($name:expr, $e:expr) => {{}};
($name:expr) => {{}};
($name:expr, $cond:expr, $e:expr) => {{}};
}

/// Configures an async callback to be triggered at the specified
/// failpoint. If the failpoint is not implemented using
/// `async_fail_point`, the execution will raise an exception.
pub fn cfg_async_callback<S, F>(name: S, f: F) -> Result<(), String>
where
S: Into<String>,
F: Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static,
{
let mut registry = REGISTRY.registry.write().unwrap();
let p = registry
.entry(name.into())
.or_insert_with(|| Arc::new(FailPoint::new()));
let action = Action::from_async_callback(f);
let actions = vec![action];
p.set_actions("callback", actions);
Ok(())
}

#[doc(hidden)]
pub async fn async_eval<R, F: FnOnce(Option<String>) -> R>(name: &str, f: F) -> Option<R> {
let p = {
let registry = REGISTRY.registry.read().unwrap();
match registry.get(name) {
None => return None,
Some(p) => p.clone(),
}
};
p.async_eval(name).await.map(f)
}

impl Action {
fn from_async_callback(
f: impl Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static,
) -> Action {
let task = Task::CallbackAsync(AsyncCallback::new(f));
Action {
task,
freq: 1.0,
count: None,
}
}
}

impl FailPoint {
#[cfg_attr(feature = "cargo-clippy", allow(clippy::option_option))]
async fn async_eval(&self, name: &str) -> Option<Option<String>> {
let task = {
let task = self
.actions
.read()
.unwrap()
.iter()
.filter_map(Action::get_task)
.next();
match task {
Some(Task::Pause) => {
// let n = self.async_pause_notify.clone();
self.async_pause_notify.notified().await;
return None;
}
Some(t) => t,
None => return None,
}
};

match task {
Task::Off => {}
Task::Return(s) => return Some(s),
Task::Sleep(t) => {
let not = Arc::new(tokio::sync::Notify::new());
let not_for_thread = not.clone();
let handle = std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(t));
not_for_thread.notify_waiters();
});
not.notified().await;
handle.join().unwrap();
}
Task::Panic(msg) => match msg {
Some(ref msg) => panic!("{}", msg),
None => panic!("failpoint {} panic", name),
},
Task::Print(msg) => match msg {
Some(ref msg) => log::info!("{}", msg),
None => log::info!("failpoint {} executed.", name),
},
Task::Pause => unreachable!(),
Task::Yield => thread::yield_now(),
Task::Delay(t) => {
let timer = Instant::now();
let timeout = Duration::from_millis(t);
while timer.elapsed() < timeout {}
}
Task::Callback(f) => {
f.run();
}
Task::CallbackAsync(f) => {
f.run().await;
}
}
None
}
}
}

#[cfg(feature = "async")]
pub use async_imp::*;

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1062,4 +1247,60 @@ mod tests {
assert_eq!(rx.recv_timeout(Duration::from_millis(500)).unwrap(), 0);
assert_eq!(f1(), 0);
}

#[cfg(feature = "async")]
#[cfg_attr(not(feature = "failpoints"), ignore)]
#[tokio::test]
async fn test_async_failpoints() {
let f1 = async {
async_fail_point!("async_cb");
};
let f2 = async {
async_fail_point!("async_cb");
};

let counter = Arc::new(AtomicUsize::new(0));
let counter2 = counter.clone();
cfg_async_callback("async_cb", move || {
counter2.fetch_add(1, Ordering::SeqCst);
Box::pin(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
})
})
.unwrap();
f1.await;
f2.await;
assert_eq!(2, counter.load(Ordering::SeqCst));

cfg("async_pause", "pause").unwrap();
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let handle = tokio::spawn(async move {
async_fail_point!("async_pause");
tx.send(()).await.unwrap();
});
tokio::time::timeout(Duration::from_millis(500), rx.recv())
.await
.unwrap_err();
remove("async_pause");
tokio::time::timeout(Duration::from_millis(500), rx.recv())
.await
.unwrap();
handle.await.unwrap();

cfg("async_sleep", "sleep(500)").unwrap();
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let handle = tokio::spawn(async move {
tx.send(()).await.unwrap();
async_fail_point!("async_sleep");
tx.send(()).await.unwrap();
});
rx.recv().await.unwrap();
tokio::time::timeout(Duration::from_millis(300), rx.recv())
.await
.unwrap_err();
tokio::time::timeout(Duration::from_millis(300), rx.recv())
.await
.unwrap();
handle.await.unwrap();
}
}
Loading