diff --git a/src/lib.rs b/src/lib.rs index 6889bb0..227831a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -535,10 +535,11 @@ struct FailPointRegistry { registry: RwLock, } -use once_cell::sync::Lazy; +use once_cell::sync::{Lazy, OnceCell}; -static REGISTRY: Lazy = Lazy::new(FailPointRegistry::default); -static SCENARIO: Lazy> = Lazy::new(|| Mutex::new(®ISTRY)); +static REGISTRY: OnceCell = OnceCell::new(); +static SCENARIO: Lazy> = + Lazy::new(|| Mutex::new(REGISTRY.get_or_init(Default::default))); /// Test scenario with configured fail points. #[derive(Debug)] @@ -636,7 +637,11 @@ pub const fn has_failpoints() -> bool { /// /// Return a vector of `(name, actions)` pairs. pub fn list() -> Vec<(String, String)> { - let registry = REGISTRY.registry.read().unwrap(); + let registry = if let Some(r) = REGISTRY.get() { + r.registry.read().unwrap() + } else { + return Vec::new(); + }; registry .iter() .map(|(name, fp)| (name.to_string(), fp.actions_str.read().unwrap().clone())) @@ -645,8 +650,13 @@ pub fn list() -> Vec<(String, String)> { #[doc(hidden)] pub fn eval) -> R>(name: &str, f: F) -> Option { + let registry = if let Some(r) = REGISTRY.get() { + &r.registry + } else { + return None; + }; let p = { - let registry = REGISTRY.registry.read().unwrap(); + let registry = registry.read().unwrap(); match registry.get(name) { None => return None, Some(p) => p.clone(), @@ -686,7 +696,11 @@ pub fn eval) -> R>(name: &str, f: F) -> Option { /// A call to `cfg` with a particular fail point name overwrites any existing actions for /// that fail point, including those set via the `FAILPOINTS` environment variable. pub fn cfg>(name: S, actions: &str) -> Result<(), String> { - let mut registry = REGISTRY.registry.write().unwrap(); + let mut registry = REGISTRY + .get_or_init(Default::default) + .registry + .write() + .unwrap(); set(&mut registry, name.into(), actions) } @@ -699,7 +713,11 @@ where S: Into, F: Fn() + Send + Sync + 'static, { - let mut registry = REGISTRY.registry.write().unwrap(); + let mut registry = REGISTRY + .get_or_init(Default::default) + .registry + .write() + .unwrap(); let p = registry .entry(name.into()) .or_insert_with(|| Arc::new(FailPoint::new())); @@ -713,7 +731,11 @@ where /// /// If the fail point doesn't exist, nothing will happen. pub fn remove>(name: S) { - let mut registry = REGISTRY.registry.write().unwrap(); + let mut registry = if let Some(r) = REGISTRY.get() { + r.registry.write().unwrap() + } else { + return; + }; if let Some(p) = registry.remove(name.as_ref()) { // wake up all pause failpoint. p.set_actions("", vec![]); @@ -937,7 +959,11 @@ mod async_imp { S: Into, F: Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static, { - let mut registry = REGISTRY.registry.write().unwrap(); + let mut registry = REGISTRY + .get_or_init(Default::default) + .registry + .write() + .unwrap(); let p = registry .entry(name.into()) .or_insert_with(|| Arc::new(FailPoint::new())); @@ -949,8 +975,13 @@ mod async_imp { #[doc(hidden)] pub async fn async_eval) -> R>(name: &str, f: F) -> Option { + let registry = if let Some(r) = REGISTRY.get() { + &r.registry + } else { + return None; + }; let p = { - let registry = REGISTRY.registry.read().unwrap(); + let registry = registry.read().unwrap(); match registry.get(name) { None => return None, Some(p) => p.clone(), @@ -1017,7 +1048,7 @@ mod async_imp { }, Task::Pause => unreachable!(), Task::Yield => thread::yield_now(), - Task::Delay(_) => { + Task::Delay(t) => { let timer = Instant::now(); let timeout = Duration::from_millis(t); while timer.elapsed() < timeout {} @@ -1251,19 +1282,17 @@ mod tests { #[cfg(feature = "async")] #[cfg_attr(not(feature = "failpoints"), ignore)] #[tokio::test] - async fn test_async_failpoint() { - use std::time::Duration; - + async fn test_async_failpoints() { let f1 = async { - async_fail_point!("cb"); + async_fail_point!("async_cb"); }; let f2 = async { - async_fail_point!("cb"); + async_fail_point!("async_cb"); }; let counter = Arc::new(AtomicUsize::new(0)); let counter2 = counter.clone(); - cfg_async_callback("cb", move || { + cfg_async_callback("async_cb", move || { counter2.fetch_add(1, Ordering::SeqCst); Box::pin(async move { tokio::time::sleep(Duration::from_millis(10)).await; @@ -1274,26 +1303,26 @@ mod tests { f2.await; assert_eq!(2, counter.load(Ordering::SeqCst)); - cfg("pause", "pause").unwrap(); + cfg("async_pause", "pause").unwrap(); let (tx, mut rx) = tokio::sync::mpsc::channel(1); let handle = tokio::spawn(async move { - async_fail_point!("pause"); + async_fail_point!("async_pause"); tx.send(()).await.unwrap(); }); tokio::time::timeout(Duration::from_millis(500), rx.recv()) .await .unwrap_err(); - remove("pause"); + remove("async_pause"); tokio::time::timeout(Duration::from_millis(500), rx.recv()) .await .unwrap(); handle.await.unwrap(); - cfg("sleep", "sleep(500)").unwrap(); + cfg("async_sleep", "sleep(500)").unwrap(); let (tx, mut rx) = tokio::sync::mpsc::channel(1); let handle = tokio::spawn(async move { tx.send(()).await.unwrap(); - async_fail_point!("sleep"); + async_fail_point!("async_sleep"); tx.send(()).await.unwrap(); }); rx.recv().await.unwrap();