diff --git a/openraft/src/core/raft_core.rs b/openraft/src/core/raft_core.rs index 9a0bdc72c..1c5f2fcfc 100644 --- a/openraft/src/core/raft_core.rs +++ b/openraft/src/core/raft_core.rs @@ -14,11 +14,11 @@ use futures::StreamExt; use futures::TryFutureExt; use maplit::btreeset; use tokio::select; -use tokio::sync::watch; use tracing::Instrument; use tracing::Level; use tracing::Span; +use crate::async_runtime::watch::WatchSender; use crate::async_runtime::MpscUnboundedSender; use crate::async_runtime::OneshotSender; use crate::async_runtime::TryRecvError; @@ -92,6 +92,7 @@ use crate::type_config::alias::MpscUnboundedReceiverOf; use crate::type_config::alias::MpscUnboundedSenderOf; use crate::type_config::alias::OneshotReceiverOf; use crate::type_config::alias::ResponderOf; +use crate::type_config::alias::WatchSenderOf; use crate::type_config::async_runtime::MpscUnboundedReceiver; use crate::type_config::TypeConfigExt; use crate::ChangeMembers; @@ -179,9 +180,9 @@ where /// A Receiver to receive callback from other components. pub(crate) rx_notify: MpscUnboundedReceiverOf>, - pub(crate) tx_metrics: watch::Sender>, - pub(crate) tx_data_metrics: watch::Sender>, - pub(crate) tx_server_metrics: watch::Sender>, + pub(crate) tx_metrics: WatchSenderOf>, + pub(crate) tx_data_metrics: WatchSenderOf>, + pub(crate) tx_server_metrics: WatchSenderOf>, pub(crate) command_state: CommandState, @@ -213,7 +214,7 @@ where tracing::debug!("update the metrics for shutdown"); { - let mut curr = self.tx_metrics.borrow().clone(); + let mut curr = self.tx_metrics.borrow_watched().clone(); curr.state = ServerState::Shutdown; curr.running_state = Err(err.clone()); diff --git a/openraft/src/metrics/wait.rs b/openraft/src/metrics/wait.rs index 343d93f74..57196381d 100644 --- a/openraft/src/metrics/wait.rs +++ b/openraft/src/metrics/wait.rs @@ -1,12 +1,12 @@ use core::time::Duration; use std::collections::BTreeSet; -use tokio::sync::watch; - +use crate::async_runtime::watch::WatchReceiver; use crate::core::ServerState; use crate::metrics::Condition; use crate::metrics::Metric; use crate::metrics::RaftMetrics; +use crate::type_config::alias::WatchReceiverOf; use crate::type_config::TypeConfigExt; use crate::LogId; use crate::OptionalSend; @@ -27,7 +27,7 @@ pub enum WaitError { /// some condition. pub struct Wait { pub timeout: Duration, - pub rx: watch::Receiver>, + pub rx: WatchReceiverOf>, } impl Wait @@ -41,7 +41,7 @@ where C: RaftTypeConfig let mut rx = self.rx.clone(); loop { - let latest = rx.borrow().clone(); + let latest = rx.borrow_watched().clone(); tracing::debug!("id={} wait {:} latest: {}", latest.id, msg.to_string(), latest); diff --git a/openraft/src/metrics/wait_test.rs b/openraft/src/metrics/wait_test.rs index 6713e0082..bccd80361 100644 --- a/openraft/src/metrics/wait_test.rs +++ b/openraft/src/metrics/wait_test.rs @@ -3,7 +3,6 @@ use std::time::Duration; use maplit::btreemap; use maplit::btreeset; -use tokio::sync::watch; use tokio::time::sleep; use crate::core::ServerState; @@ -13,6 +12,8 @@ use crate::metrics::Wait; use crate::metrics::WaitError; use crate::testing::log_id; use crate::type_config::alias::NodeIdOf; +use crate::type_config::alias::WatchSenderOf; +use crate::type_config::TypeConfigExt; use crate::vote::CommittedLeaderId; use crate::LogId; use crate::Membership; @@ -242,7 +243,7 @@ async fn test_wait_purged() -> anyhow::Result<()> { Ok(()) } -pub(crate) type InitResult = (RaftMetrics, Wait, watch::Sender>); +pub(crate) type InitResult = (RaftMetrics, Wait, WatchSenderOf>); /// Build a initial state for testing of Wait: /// Returns init metrics, Wait, and the tx to send an updated metrics. @@ -265,7 +266,7 @@ where C: RaftTypeConfig { snapshot: None, replication: None, }; - let (tx, rx) = watch::channel(init.clone()); + let (tx, rx) = C::watch_channel(init.clone()); let w = Wait { timeout: Duration::from_millis(100), rx, diff --git a/openraft/src/raft/mod.rs b/openraft/src/raft/mod.rs index 9a7f7392a..16715b709 100644 --- a/openraft/src/raft/mod.rs +++ b/openraft/src/raft/mod.rs @@ -40,12 +40,12 @@ pub use message::InstallSnapshotResponse; pub use message::SnapshotResponse; pub use message::VoteRequest; pub use message::VoteResponse; -use tokio::sync::watch; use tokio::sync::Mutex; use tracing::trace_span; use tracing::Instrument; use tracing::Level; +use crate::async_runtime::watch::WatchReceiver; use crate::async_runtime::MpscUnboundedSender; use crate::async_runtime::OneshotSender; use crate::config::Config; @@ -81,6 +81,7 @@ use crate::type_config::alias::JoinErrorOf; use crate::type_config::alias::ResponderOf; use crate::type_config::alias::ResponderReceiverOf; use crate::type_config::alias::SnapshotDataOf; +use crate::type_config::alias::WatchReceiverOf; use crate::type_config::TypeConfigExt; use crate::LogId; use crate::LogIdOptionExt; @@ -241,9 +242,9 @@ where C: RaftTypeConfig { let (tx_api, rx_api) = C::mpsc_unbounded(); let (tx_notify, rx_notify) = C::mpsc_unbounded(); - let (tx_metrics, rx_metrics) = watch::channel(RaftMetrics::new_initial(id)); - let (tx_data_metrics, rx_data_metrics) = watch::channel(RaftDataMetrics::default()); - let (tx_server_metrics, rx_server_metrics) = watch::channel(RaftServerMetrics::default()); + let (tx_metrics, rx_metrics) = C::watch_channel(RaftMetrics::new_initial(id)); + let (tx_data_metrics, rx_data_metrics) = C::watch_channel(RaftDataMetrics::default()); + let (tx_server_metrics, rx_server_metrics) = C::watch_channel(RaftServerMetrics::default()); let (tx_shutdown, rx_shutdown) = C::oneshot(); let tick_handle = Tick::spawn( @@ -478,7 +479,7 @@ where C: RaftTypeConfig /// reads. This method is perfect for making decisions on where to route client requests. #[tracing::instrument(level = "debug", skip(self))] pub async fn current_leader(&self) -> Option { - self.metrics().borrow().current_leader + self.metrics().borrow_watched().current_leader } /// Check to ensure this node is still the cluster leader, in order to guard against stale reads @@ -777,17 +778,17 @@ where C: RaftTypeConfig } /// Get a handle to the metrics channel. - pub fn metrics(&self) -> watch::Receiver> { + pub fn metrics(&self) -> WatchReceiverOf> { self.inner.rx_metrics.clone() } /// Get a handle to the data metrics channel. - pub fn data_metrics(&self) -> watch::Receiver> { + pub fn data_metrics(&self) -> WatchReceiverOf> { self.inner.rx_data_metrics.clone() } /// Get a handle to the server metrics channel. - pub fn server_metrics(&self) -> watch::Receiver> { + pub fn server_metrics(&self) -> WatchReceiverOf> { self.inner.rx_server_metrics.clone() } diff --git a/openraft/src/raft/raft_inner.rs b/openraft/src/raft/raft_inner.rs index 2fcec8173..6d3241932 100644 --- a/openraft/src/raft/raft_inner.rs +++ b/openraft/src/raft/raft_inner.rs @@ -3,7 +3,6 @@ use std::fmt::Debug; use std::future::Future; use std::sync::Arc; -use tokio::sync::watch; use tokio::sync::Mutex; use tracing::Level; @@ -20,6 +19,7 @@ use crate::raft::core_state::CoreState; use crate::type_config::alias::MpscUnboundedSenderOf; use crate::type_config::alias::OneshotReceiverOf; use crate::type_config::alias::OneshotSenderOf; +use crate::type_config::alias::WatchReceiverOf; use crate::type_config::AsyncRuntime; use crate::Config; use crate::OptionalSend; @@ -36,9 +36,9 @@ where C: RaftTypeConfig pub(in crate::raft) runtime_config: Arc, pub(in crate::raft) tick_handle: TickHandle, pub(in crate::raft) tx_api: MpscUnboundedSenderOf>, - pub(in crate::raft) rx_metrics: watch::Receiver>, - pub(in crate::raft) rx_data_metrics: watch::Receiver>, - pub(in crate::raft) rx_server_metrics: watch::Receiver>, + pub(in crate::raft) rx_metrics: WatchReceiverOf>, + pub(in crate::raft) rx_data_metrics: WatchReceiverOf>, + pub(in crate::raft) rx_server_metrics: WatchReceiverOf>, // TODO(xp): it does not need to be a async mutex. #[allow(clippy::type_complexity)] diff --git a/openraft/src/type_config.rs b/openraft/src/type_config.rs index a2a34ebea..bdd51ac1f 100644 --- a/openraft/src/type_config.rs +++ b/openraft/src/type_config.rs @@ -94,6 +94,7 @@ pub trait RaftTypeConfig: /// /// [`type-alias`]: crate::docs::feature_flags#feature-flag-type-alias pub mod alias { + use crate::async_runtime::watch; use crate::async_runtime::MpscUnbounded; use crate::raft::responder::Responder; use crate::type_config::AsyncRuntime; @@ -128,6 +129,10 @@ pub mod alias { pub type MpscUnboundedReceiverOf = as MpscUnbounded>::Receiver; pub type MpscUnboundedWeakSenderOf = as MpscUnbounded>::WeakSender; + pub type WatchOf = as AsyncRuntime>::Watch; + pub type WatchSenderOf = as watch::Watch>::Sender; + pub type WatchReceiverOf = as watch::Watch>::Receiver; + // Usually used types pub type LogIdOf = crate::LogId>; pub type VoteOf = crate::Vote>; diff --git a/openraft/src/type_config/async_runtime/impls/tokio_runtime.rs b/openraft/src/type_config/async_runtime/impls/tokio_runtime.rs index 58b10551f..fcfd5d25c 100644 --- a/openraft/src/type_config/async_runtime/impls/tokio_runtime.rs +++ b/openraft/src/type_config/async_runtime/impls/tokio_runtime.rs @@ -2,12 +2,15 @@ use std::future::Future; use std::time::Duration; use tokio::sync::mpsc; +use tokio::sync::watch as tokio_watch; use crate::async_runtime::mpsc_unbounded; use crate::async_runtime::mpsc_unbounded::MpscUnbounded; +use crate::async_runtime::watch; use crate::type_config::OneshotSender; use crate::AsyncRuntime; use crate::OptionalSend; +use crate::OptionalSync; use crate::TokioInstant; /// `Tokio` is the default asynchronous executor. @@ -80,6 +83,7 @@ impl AsyncRuntime for TokioRuntime { } type MpscUnbounded = TokioMpscUnbounded; + type Watch = TokioWatch; } impl OneshotSender for tokio::sync::oneshot::Sender { @@ -142,3 +146,45 @@ where T: OptionalSend self.upgrade() } } + +pub struct TokioWatch; + +impl watch::Watch for TokioWatch { + type Sender = tokio_watch::Sender; + type Receiver = tokio_watch::Receiver; + + type Ref<'a, T: OptionalSend + 'a> = tokio_watch::Ref<'a, T>; + + fn channel(init: T) -> (Self::Sender, Self::Receiver) { + tokio_watch::channel(init) + } +} + +impl watch::WatchSender for tokio_watch::Sender +where T: OptionalSend + OptionalSync +{ + fn send(&self, value: T) -> Result<(), watch::SendError> { + self.send(value).map_err(|e| watch::SendError(e.0)) + } + + fn send_if_modified(&self, modify: F) -> bool + where F: FnOnce(&mut T) -> bool { + self.send_if_modified(modify) + } + + fn borrow_watched(&self) -> ::Ref<'_, T> { + self.borrow() + } +} + +impl watch::WatchReceiver for tokio_watch::Receiver +where T: OptionalSend + OptionalSync +{ + async fn changed(&mut self) -> Result<(), watch::RecvError> { + self.changed().await.map_err(|_| watch::RecvError(())) + } + + fn borrow_watched(&self) -> ::Ref<'_, T> { + self.borrow() + } +} diff --git a/openraft/src/type_config/async_runtime/mod.rs b/openraft/src/type_config/async_runtime/mod.rs index db943e574..b61cd0498 100644 --- a/openraft/src/type_config/async_runtime/mod.rs +++ b/openraft/src/type_config/async_runtime/mod.rs @@ -10,6 +10,7 @@ pub(crate) mod impls { } pub mod mpsc_unbounded; mod oneshot; +pub mod watch; use std::fmt::Debug; use std::fmt::Display; @@ -23,6 +24,7 @@ pub use mpsc_unbounded::MpscUnboundedWeakSender; pub use mpsc_unbounded::SendError; pub use mpsc_unbounded::TryRecvError; pub use oneshot::OneshotSender; +pub use watch::Watch; use crate::Instant; use crate::OptionalSend; @@ -116,4 +118,6 @@ pub trait AsyncRuntime: Debug + Default + PartialEq + Eq + OptionalSend + Option where T: OptionalSend; type MpscUnbounded: MpscUnbounded; + + type Watch: Watch; } diff --git a/openraft/src/type_config/async_runtime/watch/mod.rs b/openraft/src/type_config/async_runtime/watch/mod.rs new file mode 100644 index 000000000..84ee0b91a --- /dev/null +++ b/openraft/src/type_config/async_runtime/watch/mod.rs @@ -0,0 +1,39 @@ +mod watch_error; + +use openraft_macros::add_async_trait; +pub use watch_error::RecvError; +pub use watch_error::SendError; + +use crate::OptionalSend; +use crate::OptionalSync; + +pub trait Watch: Sized + OptionalSend { + type Sender: WatchSender; + type Receiver: WatchReceiver; + + type Ref<'a, T: OptionalSend + 'a>: std::ops::Deref + 'a; + + fn channel(init: T) -> (Self::Sender, Self::Receiver); +} + +pub trait WatchSender: OptionalSend + Clone +where + W: Watch, + T: OptionalSend + OptionalSync, +{ + fn send(&self, value: T) -> Result<(), SendError>; + fn send_if_modified(&self, modify: F) -> bool + where F: FnOnce(&mut T) -> bool; + + fn borrow_watched(&self) -> W::Ref<'_, T>; +} + +#[add_async_trait] +pub trait WatchReceiver: OptionalSend + OptionalSync + Clone +where + W: Watch, + T: OptionalSend + OptionalSync, +{ + async fn changed(&mut self) -> Result<(), RecvError>; + fn borrow_watched(&self) -> W::Ref<'_, T>; +} diff --git a/openraft/src/type_config/async_runtime/watch/watch_error.rs b/openraft/src/type_config/async_runtime/watch/watch_error.rs new file mode 100644 index 000000000..bb766dcf8 --- /dev/null +++ b/openraft/src/type_config/async_runtime/watch/watch_error.rs @@ -0,0 +1,37 @@ +use std::fmt; + +/// Error returned by the `WatchSender`. +#[derive(PartialEq, Eq, Clone, Copy)] +pub struct SendError(pub T); + +impl fmt::Debug for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SendError").finish_non_exhaustive() + } +} + +impl fmt::Display for SendError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "watch channel closed") + } +} + +impl std::error::Error for SendError {} + +/// Error returned by the `WatchReceiver`. +#[derive(PartialEq, Eq, Clone, Copy)] +pub struct RecvError(pub ()); + +impl fmt::Debug for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RecvError").finish_non_exhaustive() + } +} + +impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "watch channel closed") + } +} + +impl std::error::Error for RecvError {} diff --git a/openraft/src/type_config/util.rs b/openraft/src/type_config/util.rs index eca62d7f6..932503929 100644 --- a/openraft/src/type_config/util.rs +++ b/openraft/src/type_config/util.rs @@ -3,6 +3,7 @@ use std::time::Duration; use openraft_macros::since; +use crate::async_runtime::watch::Watch; use crate::async_runtime::MpscUnbounded; use crate::type_config::alias::AsyncRuntimeOf; use crate::type_config::alias::InstantOf; @@ -14,9 +15,13 @@ use crate::type_config::alias::OneshotReceiverOf; use crate::type_config::alias::OneshotSenderOf; use crate::type_config::alias::SleepOf; use crate::type_config::alias::TimeoutOf; +use crate::type_config::alias::WatchOf; +use crate::type_config::alias::WatchReceiverOf; +use crate::type_config::alias::WatchSenderOf; use crate::type_config::AsyncRuntime; use crate::Instant; use crate::OptionalSend; +use crate::OptionalSync; use crate::RaftTypeConfig; /// Collection of utility methods to `RaftTypeConfig` function. @@ -73,6 +78,16 @@ pub trait TypeConfigExt: RaftTypeConfig { MpscUnboundedOf::::channel() } + /// Creates a watch channel for watching for changes to a value from multiple + /// points in the code base. + /// + /// This is just a wrapper of + /// [`AsyncRuntime::Watch::channel()`](`crate::async_runtime::Watch::channel`). + fn watch_channel(init: T) -> (WatchSenderOf, WatchReceiverOf) + where T: OptionalSend + OptionalSync { + WatchOf::::channel(init) + } + // Task methods /// Spawn a new task.