Skip to content

Commit

Permalink
Refactor: Replace async Mutex for RaftInner.core_state with standar…
Browse files Browse the repository at this point in the history
…d Mutex and a watch channel

In this commit, when joining the `RaftCore` task, the first caller switches
`core_state` from `Running` to `Joining`, releases the lock, and then awaits
the completion of the `RaftCore` task. Any other caller that observes the
`Joining` state waits for the first caller to finish by monitoring a `watch`
channel created by that first caller.
  • Loading branch information
drmingdrmer committed Jul 30, 2024
1 parent 69ab424 commit 5c82dfd
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 16 deletions.
4 changes: 4 additions & 0 deletions openraft/src/raft/core_state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::error::Fatal;
use crate::error::Infallible;
use crate::type_config::alias::JoinHandleOf;
use crate::type_config::alias::WatchReceiverOf;
use crate::RaftTypeConfig;

/// The running state of RaftCore
Expand All @@ -10,6 +11,9 @@ where C: RaftTypeConfig
/// The RaftCore task is still running.
Running(JoinHandleOf<C, Result<Infallible, Fatal<C>>>),

/// A caller is currently joining the RaftCore task. Other callers wait on this
/// watch receiver, which is set to `true` once the task result has been recorded.
Joining(WatchReceiverOf<C, bool>),

/// The RaftCore task has finished. The return value of the task is stored.
Done(Result<Infallible, Fatal<C>>),
}
Expand Down
4 changes: 2 additions & 2 deletions openraft/src/raft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ where C: RaftTypeConfig
rx_data_metrics,
rx_server_metrics,
tx_shutdown: std::sync::Mutex::new(Some(tx_shutdown)),
core_state: Mutex::new(CoreState::Running(core_handle)),
core_state: std::sync::Mutex::new(CoreState::Running(core_handle)),

snapshot: Mutex::new(None),
};
Expand Down Expand Up @@ -828,7 +828,7 @@ where C: RaftTypeConfig
tracing::debug!("{} receives result is error: {:?}", func_name!(), recv_res.is_err());

let Ok(v) = recv_res else {
if self.inner.is_core_running().await {
if self.inner.is_core_running() {
return Ok(Err(InvalidStateMachineType::new::<SM>()));
} else {
let fatal = self.inner.get_core_stopped_error("receiving rx from RaftCore", None::<&'static str>).await;
Expand Down
70 changes: 56 additions & 14 deletions openraft/src/raft/raft_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::Level;

use crate::async_runtime::watch::WatchReceiver;
use crate::async_runtime::watch::WatchSender;
use crate::async_runtime::MpscUnboundedSender;
use crate::config::RuntimeConfig;
use crate::core::raft_msg::external_command::ExternalCommand;
Expand All @@ -16,11 +18,13 @@ use crate::error::RaftError;
use crate::metrics::RaftDataMetrics;
use crate::metrics::RaftServerMetrics;
use crate::raft::core_state::CoreState;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::type_config::alias::MpscUnboundedSenderOf;
use crate::type_config::alias::OneshotReceiverOf;
use crate::type_config::alias::OneshotSenderOf;
use crate::type_config::alias::WatchReceiverOf;
use crate::type_config::AsyncRuntime;
use crate::type_config::TypeConfigExt;
use crate::Config;
use crate::OptionalSend;
use crate::RaftMetrics;
Expand All @@ -41,7 +45,7 @@ where C: RaftTypeConfig
pub(in crate::raft) rx_server_metrics: WatchReceiverOf<C, RaftServerMetrics<C>>,

pub(in crate::raft) tx_shutdown: std::sync::Mutex<Option<OneshotSenderOf<C, ()>>>,
pub(in crate::raft) core_state: Mutex<CoreState<C>>,
pub(in crate::raft) core_state: std::sync::Mutex<CoreState<C>>,

/// The ongoing snapshot transmission.
pub(in crate::raft) snapshot: Mutex<Option<crate::network::snapshot_transport::Streaming<C>>>,
Expand Down Expand Up @@ -129,8 +133,8 @@ where C: RaftTypeConfig
Ok(())
}

pub(in crate::raft) async fn is_core_running(&self) -> bool {
let state = self.core_state.lock().await;
pub(in crate::raft) fn is_core_running(&self) -> bool {
let state = self.core_state.lock().unwrap();
state.is_running()
}

Expand All @@ -145,7 +149,7 @@ where C: RaftTypeConfig

// Retrieve the result.
let core_res = {
let state = self.core_state.lock().await;
let state = self.core_state.lock().unwrap();
if let CoreState::Done(core_task_res) = &*state {
core_task_res.clone()
} else {
Expand All @@ -170,15 +174,40 @@ where C: RaftTypeConfig
/// Wait for `RaftCore` task to finish and record the returned value from the task.
#[tracing::instrument(level = "debug", skip_all)]
pub(in crate::raft) async fn join_core_task(&self) {
let mut state = self.core_state.lock().await;
match &mut *state {
CoreState::Running(handle) => {
let res = handle.await;
tracing::info!(res = debug(&res), "RaftCore exited");
// If RaftCore is still Running, take its join handle and switch to Joining;
// if another caller is already joining it, get a clone of the watch receiver as `Err`.
let running_res = {
let mut state = self.core_state.lock().unwrap();

match &*state {
CoreState::Running(_) => {
let (tx, rx) = C::watch_channel::<bool>(false);

let prev = std::mem::replace(&mut *state, CoreState::Joining(rx));

let CoreState::Running(join_handle) = prev else {
unreachable!()
};

Ok((join_handle, tx))
}
CoreState::Joining(watch_rx) => Err(watch_rx.clone()),
CoreState::Done(_) => {
// RaftCore has already finished exiting, nothing to do
return;
}
}
};

match running_res {
Ok((join_handle, tx)) => {
let join_res = join_handle.await;

let core_task_res = match res {
tracing::info!(res = debug(&join_res), "RaftCore exited");

let core_task_res = match join_res {
Err(err) => {
if C::AsyncRuntime::is_panic(&err) {
if AsyncRuntimeOf::<C>::is_panic(&err) {
Err(Fatal::Panicked)
} else {
Err(Fatal::Stopped)
Expand All @@ -187,10 +216,23 @@ where C: RaftTypeConfig
Ok(returned_res) => returned_res,
};

*state = CoreState::Done(core_task_res);
{
let mut state = self.core_state.lock().unwrap();
*state = CoreState::Done(core_task_res);
}
tx.send(true).ok();
}
CoreState::Done(_) => {
// RaftCore has already quit, nothing to do
Err(mut rx) => {
// Another caller is already joining the core; wait until it signals completion
// or its sender is dropped.
loop {
let res = rx.changed().await;
if res.is_err() {
break;
}
if *rx.borrow_watched() {
break;
}
}
}
}
}
Expand Down

0 comments on commit 5c82dfd

Please sign in to comment.