Skip to content

Commit

Permalink
Refactor: Abstract the watch channel into a traits (databendlabs#1167)
Browse files Browse the repository at this point in the history
* Refactor: Abstract the watch channel into a traits
  • Loading branch information
ariesdevil committed Jul 10, 2024
1 parent e69aea1 commit 1b4ec53
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 24 deletions.
11 changes: 6 additions & 5 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ use futures::StreamExt;
use futures::TryFutureExt;
use maplit::btreeset;
use tokio::select;
use tokio::sync::watch;
use tracing::Instrument;
use tracing::Level;
use tracing::Span;

use crate::async_runtime::watch::WatchSender;
use crate::async_runtime::MpscUnboundedSender;
use crate::async_runtime::OneshotSender;
use crate::async_runtime::TryRecvError;
Expand Down Expand Up @@ -92,6 +92,7 @@ use crate::type_config::alias::MpscUnboundedReceiverOf;
use crate::type_config::alias::MpscUnboundedSenderOf;
use crate::type_config::alias::OneshotReceiverOf;
use crate::type_config::alias::ResponderOf;
use crate::type_config::alias::WatchSenderOf;
use crate::type_config::async_runtime::MpscUnboundedReceiver;
use crate::type_config::TypeConfigExt;
use crate::ChangeMembers;
Expand Down Expand Up @@ -179,9 +180,9 @@ where
/// A Receiver to receive callback from other components.
pub(crate) rx_notify: MpscUnboundedReceiverOf<C, Notify<C>>,

pub(crate) tx_metrics: watch::Sender<RaftMetrics<C>>,
pub(crate) tx_data_metrics: watch::Sender<RaftDataMetrics<C>>,
pub(crate) tx_server_metrics: watch::Sender<RaftServerMetrics<C>>,
pub(crate) tx_metrics: WatchSenderOf<C, RaftMetrics<C>>,
pub(crate) tx_data_metrics: WatchSenderOf<C, RaftDataMetrics<C>>,
pub(crate) tx_server_metrics: WatchSenderOf<C, RaftServerMetrics<C>>,

pub(crate) command_state: CommandState,

Expand Down Expand Up @@ -213,7 +214,7 @@ where

tracing::debug!("update the metrics for shutdown");
{
let mut curr = self.tx_metrics.borrow().clone();
let mut curr = self.tx_metrics.borrow_watched().clone();
curr.state = ServerState::Shutdown;
curr.running_state = Err(err.clone());

Expand Down
8 changes: 4 additions & 4 deletions openraft/src/metrics/wait.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use core::time::Duration;
use std::collections::BTreeSet;

use tokio::sync::watch;

use crate::async_runtime::watch::WatchReceiver;
use crate::core::ServerState;
use crate::metrics::Condition;
use crate::metrics::Metric;
use crate::metrics::RaftMetrics;
use crate::type_config::alias::WatchReceiverOf;
use crate::type_config::TypeConfigExt;
use crate::LogId;
use crate::OptionalSend;
Expand All @@ -27,7 +27,7 @@ pub enum WaitError {
/// some condition.
pub struct Wait<C: RaftTypeConfig> {
pub timeout: Duration,
pub rx: watch::Receiver<RaftMetrics<C>>,
pub rx: WatchReceiverOf<C, RaftMetrics<C>>,
}

impl<C> Wait<C>
Expand All @@ -41,7 +41,7 @@ where C: RaftTypeConfig

let mut rx = self.rx.clone();
loop {
let latest = rx.borrow().clone();
let latest = rx.borrow_watched().clone();

tracing::debug!("id={} wait {:} latest: {}", latest.id, msg.to_string(), latest);

Expand Down
7 changes: 4 additions & 3 deletions openraft/src/metrics/wait_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::time::Duration;

use maplit::btreemap;
use maplit::btreeset;
use tokio::sync::watch;
use tokio::time::sleep;

use crate::core::ServerState;
Expand All @@ -13,6 +12,8 @@ use crate::metrics::Wait;
use crate::metrics::WaitError;
use crate::testing::log_id;
use crate::type_config::alias::NodeIdOf;
use crate::type_config::alias::WatchSenderOf;
use crate::type_config::TypeConfigExt;
use crate::vote::CommittedLeaderId;
use crate::LogId;
use crate::Membership;
Expand Down Expand Up @@ -242,7 +243,7 @@ async fn test_wait_purged() -> anyhow::Result<()> {
Ok(())
}

pub(crate) type InitResult<C> = (RaftMetrics<C>, Wait<C>, watch::Sender<RaftMetrics<C>>);
pub(crate) type InitResult<C> = (RaftMetrics<C>, Wait<C>, WatchSenderOf<C, RaftMetrics<C>>);

/// Build a initial state for testing of Wait:
/// Returns init metrics, Wait, and the tx to send an updated metrics.
Expand All @@ -265,7 +266,7 @@ where C: RaftTypeConfig {
snapshot: None,
replication: None,
};
let (tx, rx) = watch::channel(init.clone());
let (tx, rx) = C::watch_channel(init.clone());
let w = Wait {
timeout: Duration::from_millis(100),
rx,
Expand Down
17 changes: 9 additions & 8 deletions openraft/src/raft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ pub use message::InstallSnapshotResponse;
pub use message::SnapshotResponse;
pub use message::VoteRequest;
pub use message::VoteResponse;
use tokio::sync::watch;
use tokio::sync::Mutex;
use tracing::trace_span;
use tracing::Instrument;
use tracing::Level;

use crate::async_runtime::watch::WatchReceiver;
use crate::async_runtime::MpscUnboundedSender;
use crate::async_runtime::OneshotSender;
use crate::config::Config;
Expand Down Expand Up @@ -81,6 +81,7 @@ use crate::type_config::alias::JoinErrorOf;
use crate::type_config::alias::ResponderOf;
use crate::type_config::alias::ResponderReceiverOf;
use crate::type_config::alias::SnapshotDataOf;
use crate::type_config::alias::WatchReceiverOf;
use crate::type_config::TypeConfigExt;
use crate::LogId;
use crate::LogIdOptionExt;
Expand Down Expand Up @@ -241,9 +242,9 @@ where C: RaftTypeConfig
{
let (tx_api, rx_api) = C::mpsc_unbounded();
let (tx_notify, rx_notify) = C::mpsc_unbounded();
let (tx_metrics, rx_metrics) = watch::channel(RaftMetrics::new_initial(id));
let (tx_data_metrics, rx_data_metrics) = watch::channel(RaftDataMetrics::default());
let (tx_server_metrics, rx_server_metrics) = watch::channel(RaftServerMetrics::default());
let (tx_metrics, rx_metrics) = C::watch_channel(RaftMetrics::new_initial(id));
let (tx_data_metrics, rx_data_metrics) = C::watch_channel(RaftDataMetrics::default());
let (tx_server_metrics, rx_server_metrics) = C::watch_channel(RaftServerMetrics::default());
let (tx_shutdown, rx_shutdown) = C::oneshot();

let tick_handle = Tick::spawn(
Expand Down Expand Up @@ -478,7 +479,7 @@ where C: RaftTypeConfig
/// reads. This method is perfect for making decisions on where to route client requests.
#[tracing::instrument(level = "debug", skip(self))]
pub async fn current_leader(&self) -> Option<C::NodeId> {
self.metrics().borrow().current_leader
self.metrics().borrow_watched().current_leader
}

/// Check to ensure this node is still the cluster leader, in order to guard against stale reads
Expand Down Expand Up @@ -777,17 +778,17 @@ where C: RaftTypeConfig
}

/// Get a handle to the metrics channel.
pub fn metrics(&self) -> watch::Receiver<RaftMetrics<C>> {
pub fn metrics(&self) -> WatchReceiverOf<C, RaftMetrics<C>> {
self.inner.rx_metrics.clone()
}

/// Get a handle to the data metrics channel.
pub fn data_metrics(&self) -> watch::Receiver<RaftDataMetrics<C>> {
pub fn data_metrics(&self) -> WatchReceiverOf<C, RaftDataMetrics<C>> {
self.inner.rx_data_metrics.clone()
}

/// Get a handle to the server metrics channel.
pub fn server_metrics(&self) -> watch::Receiver<RaftServerMetrics<C>> {
pub fn server_metrics(&self) -> WatchReceiverOf<C, RaftServerMetrics<C>> {
self.inner.rx_server_metrics.clone()
}

Expand Down
8 changes: 4 additions & 4 deletions openraft/src/raft/raft_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::fmt::Debug;
use std::future::Future;
use std::sync::Arc;

use tokio::sync::watch;
use tokio::sync::Mutex;
use tracing::Level;

Expand All @@ -20,6 +19,7 @@ use crate::raft::core_state::CoreState;
use crate::type_config::alias::MpscUnboundedSenderOf;
use crate::type_config::alias::OneshotReceiverOf;
use crate::type_config::alias::OneshotSenderOf;
use crate::type_config::alias::WatchReceiverOf;
use crate::type_config::AsyncRuntime;
use crate::Config;
use crate::OptionalSend;
Expand All @@ -36,9 +36,9 @@ where C: RaftTypeConfig
pub(in crate::raft) runtime_config: Arc<RuntimeConfig>,
pub(in crate::raft) tick_handle: TickHandle<C>,
pub(in crate::raft) tx_api: MpscUnboundedSenderOf<C, RaftMsg<C>>,
pub(in crate::raft) rx_metrics: watch::Receiver<RaftMetrics<C>>,
pub(in crate::raft) rx_data_metrics: watch::Receiver<RaftDataMetrics<C>>,
pub(in crate::raft) rx_server_metrics: watch::Receiver<RaftServerMetrics<C>>,
pub(in crate::raft) rx_metrics: WatchReceiverOf<C, RaftMetrics<C>>,
pub(in crate::raft) rx_data_metrics: WatchReceiverOf<C, RaftDataMetrics<C>>,
pub(in crate::raft) rx_server_metrics: WatchReceiverOf<C, RaftServerMetrics<C>>,

// TODO(xp): it does not need to be a async mutex.
#[allow(clippy::type_complexity)]
Expand Down
5 changes: 5 additions & 0 deletions openraft/src/type_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ pub trait RaftTypeConfig:
///
/// [`type-alias`]: crate::docs::feature_flags#feature-flag-type-alias
pub mod alias {
use crate::async_runtime::watch;
use crate::async_runtime::MpscUnbounded;
use crate::raft::responder::Responder;
use crate::type_config::AsyncRuntime;
Expand Down Expand Up @@ -128,6 +129,10 @@ pub mod alias {
pub type MpscUnboundedReceiverOf<C, T> = <Mpsc<C> as MpscUnbounded>::Receiver<T>;
pub type MpscUnboundedWeakSenderOf<C, T> = <Mpsc<C> as MpscUnbounded>::WeakSender<T>;

pub type WatchOf<C> = <Rt<C> as AsyncRuntime>::Watch;
pub type WatchSenderOf<C, T> = <WatchOf<C> as watch::Watch>::Sender<T>;
pub type WatchReceiverOf<C, T> = <WatchOf<C> as watch::Watch>::Receiver<T>;

// Usually used types
pub type LogIdOf<C> = crate::LogId<NodeIdOf<C>>;
pub type VoteOf<C> = crate::Vote<NodeIdOf<C>>;
Expand Down
46 changes: 46 additions & 0 deletions openraft/src/type_config/async_runtime/impls/tokio_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ use std::future::Future;
use std::time::Duration;

use tokio::sync::mpsc;
use tokio::sync::watch as tokio_watch;

use crate::async_runtime::mpsc_unbounded;
use crate::async_runtime::mpsc_unbounded::MpscUnbounded;
use crate::async_runtime::watch;
use crate::type_config::OneshotSender;
use crate::AsyncRuntime;
use crate::OptionalSend;
use crate::OptionalSync;
use crate::TokioInstant;

/// `Tokio` is the default asynchronous executor.
Expand Down Expand Up @@ -80,6 +83,7 @@ impl AsyncRuntime for TokioRuntime {
}

type MpscUnbounded = TokioMpscUnbounded;
type Watch = TokioWatch;
}

impl<T> OneshotSender<T> for tokio::sync::oneshot::Sender<T> {
Expand Down Expand Up @@ -142,3 +146,45 @@ where T: OptionalSend
self.upgrade()
}
}

pub struct TokioWatch;

impl watch::Watch for TokioWatch {
type Sender<T: OptionalSend + OptionalSync> = tokio_watch::Sender<T>;
type Receiver<T: OptionalSend + OptionalSync> = tokio_watch::Receiver<T>;

type Ref<'a, T: OptionalSend + 'a> = tokio_watch::Ref<'a, T>;

fn channel<T: OptionalSend + OptionalSync>(init: T) -> (Self::Sender<T>, Self::Receiver<T>) {
tokio_watch::channel(init)
}
}

impl<T> watch::WatchSender<TokioWatch, T> for tokio_watch::Sender<T>
where T: OptionalSend + OptionalSync
{
fn send(&self, value: T) -> Result<(), watch::SendError<T>> {
self.send(value).map_err(|e| watch::SendError(e.0))
}

fn send_if_modified<F>(&self, modify: F) -> bool
where F: FnOnce(&mut T) -> bool {
self.send_if_modified(modify)
}

fn borrow_watched(&self) -> <TokioWatch as watch::Watch>::Ref<'_, T> {
self.borrow()
}
}

impl<T> watch::WatchReceiver<TokioWatch, T> for tokio_watch::Receiver<T>
where T: OptionalSend + OptionalSync
{
async fn changed(&mut self) -> Result<(), watch::RecvError> {
self.changed().await.map_err(|_| watch::RecvError(()))
}

fn borrow_watched(&self) -> <TokioWatch as watch::Watch>::Ref<'_, T> {
self.borrow()
}
}
4 changes: 4 additions & 0 deletions openraft/src/type_config/async_runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub(crate) mod impls {
}
pub mod mpsc_unbounded;
mod oneshot;
pub mod watch;

use std::fmt::Debug;
use std::fmt::Display;
Expand All @@ -23,6 +24,7 @@ pub use mpsc_unbounded::MpscUnboundedWeakSender;
pub use mpsc_unbounded::SendError;
pub use mpsc_unbounded::TryRecvError;
pub use oneshot::OneshotSender;
pub use watch::Watch;

use crate::Instant;
use crate::OptionalSend;
Expand Down Expand Up @@ -116,4 +118,6 @@ pub trait AsyncRuntime: Debug + Default + PartialEq + Eq + OptionalSend + Option
where T: OptionalSend;

type MpscUnbounded: MpscUnbounded;

type Watch: Watch;
}
39 changes: 39 additions & 0 deletions openraft/src/type_config/async_runtime/watch/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
mod watch_error;

use openraft_macros::add_async_trait;
pub use watch_error::RecvError;
pub use watch_error::SendError;

use crate::OptionalSend;
use crate::OptionalSync;

pub trait Watch: Sized + OptionalSend {
type Sender<T: OptionalSend + OptionalSync>: WatchSender<Self, T>;
type Receiver<T: OptionalSend + OptionalSync>: WatchReceiver<Self, T>;

type Ref<'a, T: OptionalSend + 'a>: std::ops::Deref<Target = T> + 'a;

fn channel<T: OptionalSend + OptionalSync>(init: T) -> (Self::Sender<T>, Self::Receiver<T>);
}

pub trait WatchSender<W, T>: OptionalSend + Clone
where
W: Watch,
T: OptionalSend + OptionalSync,
{
fn send(&self, value: T) -> Result<(), SendError<T>>;
fn send_if_modified<F>(&self, modify: F) -> bool
where F: FnOnce(&mut T) -> bool;

fn borrow_watched(&self) -> W::Ref<'_, T>;
}

#[add_async_trait]
pub trait WatchReceiver<W, T>: OptionalSend + OptionalSync + Clone
where
W: Watch,
T: OptionalSend + OptionalSync,
{
async fn changed(&mut self) -> Result<(), RecvError>;
fn borrow_watched(&self) -> W::Ref<'_, T>;
}
37 changes: 37 additions & 0 deletions openraft/src/type_config/async_runtime/watch/watch_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use std::fmt;

/// Error returned by the `WatchSender`.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct SendError<T>(pub T);

impl<T> fmt::Debug for SendError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SendError").finish_non_exhaustive()
}
}

impl<T> fmt::Display for SendError<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "watch channel closed")
}
}

impl<T> std::error::Error for SendError<T> {}

/// Error returned by the `WatchReceiver`.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct RecvError(pub ());

impl fmt::Debug for RecvError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RecvError").finish_non_exhaustive()
}
}

impl fmt::Display for RecvError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "watch channel closed")
}
}

impl std::error::Error for RecvError {}
Loading

0 comments on commit 1b4ec53

Please sign in to comment.