diff --git a/openraft/src/raft/core_state.rs b/openraft/src/raft/core_state.rs index b6892af26..a02e3463c 100644 --- a/openraft/src/raft/core_state.rs +++ b/openraft/src/raft/core_state.rs @@ -1,6 +1,7 @@ use crate::error::Fatal; use crate::error::Infallible; use crate::type_config::alias::JoinHandleOf; +use crate::type_config::alias::WatchReceiverOf; use crate::RaftTypeConfig; /// The running state of RaftCore @@ -10,6 +11,9 @@ where C: RaftTypeConfig /// The RaftCore task is still running. Running(JoinHandleOf>>), + /// The RaftCore task is waiting for a signal to finish joining. + Joining(WatchReceiverOf), + /// The RaftCore task has finished. The return value of the task is stored. Done(Result>), } diff --git a/openraft/src/raft/mod.rs b/openraft/src/raft/mod.rs index dac8d8ab2..7a3350135 100644 --- a/openraft/src/raft/mod.rs +++ b/openraft/src/raft/mod.rs @@ -319,7 +319,7 @@ where C: RaftTypeConfig rx_data_metrics, rx_server_metrics, tx_shutdown: std::sync::Mutex::new(Some(tx_shutdown)), - core_state: Mutex::new(CoreState::Running(core_handle)), + core_state: std::sync::Mutex::new(CoreState::Running(core_handle)), snapshot: Mutex::new(None), }; @@ -828,7 +828,7 @@ where C: RaftTypeConfig tracing::debug!("{} receives result is error: {:?}", func_name!(), recv_res.is_err()); let Ok(v) = recv_res else { - if self.inner.is_core_running().await { + if self.inner.is_core_running() { return Ok(Err(InvalidStateMachineType::new::())); } else { let fatal = self.inner.get_core_stopped_error("receiving rx from RaftCore", None::<&'static str>).await; diff --git a/openraft/src/raft/raft_inner.rs b/openraft/src/raft/raft_inner.rs index 8f26333f4..18c139230 100644 --- a/openraft/src/raft/raft_inner.rs +++ b/openraft/src/raft/raft_inner.rs @@ -6,6 +6,8 @@ use std::sync::Arc; use tokio::sync::Mutex; use tracing::Level; +use crate::async_runtime::watch::WatchReceiver; +use crate::async_runtime::watch::WatchSender; use crate::async_runtime::MpscUnboundedSender; use crate::config::RuntimeConfig; use crate::core::raft_msg::external_command::ExternalCommand; @@ -16,11 +18,13 @@ use crate::error::RaftError; use crate::metrics::RaftDataMetrics; use crate::metrics::RaftServerMetrics; use crate::raft::core_state::CoreState; +use crate::type_config::alias::AsyncRuntimeOf; use crate::type_config::alias::MpscUnboundedSenderOf; use crate::type_config::alias::OneshotReceiverOf; use crate::type_config::alias::OneshotSenderOf; use crate::type_config::alias::WatchReceiverOf; use crate::type_config::AsyncRuntime; +use crate::type_config::TypeConfigExt; use crate::Config; use crate::OptionalSend; use crate::RaftMetrics; @@ -41,7 +45,7 @@ where C: RaftTypeConfig pub(in crate::raft) rx_server_metrics: WatchReceiverOf>, pub(in crate::raft) tx_shutdown: std::sync::Mutex>>, - pub(in crate::raft) core_state: Mutex>, + pub(in crate::raft) core_state: std::sync::Mutex>, /// The ongoing snapshot transmission. pub(in crate::raft) snapshot: Mutex>>, @@ -129,8 +133,8 @@ where C: RaftTypeConfig Ok(()) } - pub(in crate::raft) async fn is_core_running(&self) -> bool { - let state = self.core_state.lock().await; + pub(in crate::raft) fn is_core_running(&self) -> bool { + let state = self.core_state.lock().unwrap(); state.is_running() } @@ -145,7 +149,7 @@ where C: RaftTypeConfig // Retrieve the result. let core_res = { - let state = self.core_state.lock().await; + let state = self.core_state.lock().unwrap(); if let CoreState::Done(core_task_res) = &*state { core_task_res.clone() } else { @@ -170,15 +174,40 @@ where C: RaftTypeConfig /// Wait for `RaftCore` task to finish and record the returned value from the task. #[tracing::instrument(level = "debug", skip_all)] pub(in crate::raft) async fn join_core_task(&self) { - let mut state = self.core_state.lock().await; - match &mut *state { - CoreState::Running(handle) => { - let res = handle.await; - tracing::info!(res = debug(&res), "RaftCore exited"); + // Get the Running state of RaftCore, + // or an error if RaftCore has been in Joining state. + let running_res = { + let mut state = self.core_state.lock().unwrap(); + + match &*state { + CoreState::Running(_) => { + let (tx, rx) = C::watch_channel::(false); + + let prev = std::mem::replace(&mut *state, CoreState::Joining(rx)); + + let CoreState::Running(join_handle) = prev else { + unreachable!() + }; + + Ok((join_handle, tx)) + } + CoreState::Joining(watch_rx) => Err(watch_rx.clone()), + CoreState::Done(_) => { + // RaftCore has already finished exiting, nothing to do + return; + } + } + }; + + match running_res { + Ok((join_handle, tx)) => { + let join_res = join_handle.await; - let core_task_res = match res { + tracing::info!(res = debug(&join_res), "RaftCore exited"); + + let core_task_res = match join_res { Err(err) => { - if C::AsyncRuntime::is_panic(&err) { + if AsyncRuntimeOf::::is_panic(&err) { Err(Fatal::Panicked) } else { Err(Fatal::Stopped) @@ -187,10 +216,23 @@ where C: RaftTypeConfig Ok(returned_res) => returned_res, }; - *state = CoreState::Done(core_task_res); + { + let mut state = self.core_state.lock().unwrap(); + *state = CoreState::Done(core_task_res); + } + tx.send(true).ok(); } - CoreState::Done(_) => { - // RaftCore has already quit, nothing to do + Err(mut rx) => { + // Other thread is waiting for the core to finish. + loop { + let res = rx.changed().await; + if res.is_err() { + break; + } + if *rx.borrow_watched() { + break; + } + } } } }