Skip to content

Commit

Permalink
Merge pull request databendlabs#995 from schreter/thread_rng_in_runtime
Browse files Browse the repository at this point in the history
Feature: Add random number generator to `AsyncRuntime`
  • Loading branch information
drmingdrmer authored Jan 19, 2024
2 parents deb51db + 9aefad2 commit 49e8adf
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 6 deletions.
19 changes: 19 additions & 0 deletions openraft/src/async_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use crate::TokioInstant;
/// The intention of this trait is to allow an application using this crate to bind an asynchronous
/// runtime that suits it the best.
///
/// Some additional related functions are also exposed by this trait.
///
/// ## Note
///
/// The default asynchronous runtime is `tokio`.
Expand All @@ -39,6 +41,9 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// to await the outcome of a [`Future`].
type Timeout<R, T: Future<Output = R> + OptionalSend>: Future<Output = Result<R, Self::TimeoutError>> + OptionalSend;

/// Type of a thread-local random number generator.
type ThreadLocalRng: rand::Rng;

/// Spawn a new task.
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
where
Expand All @@ -62,6 +67,14 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static

/// Abort the task associated with the supplied join handle.
fn abort<T: OptionalSend + 'static>(join_handle: &Self::JoinHandle<T>);

/// Get the random number generator to use for generating random numbers.
///
/// # Note
///
/// This is a per-thread instance, which cannot be shared across threads or
/// sent to another thread.
fn thread_rng() -> Self::ThreadLocalRng;
}

/// `Tokio` is the default asynchronous executor.
Expand All @@ -75,6 +88,7 @@ impl AsyncRuntime for TokioRuntime {
type Instant = TokioInstant;
type TimeoutError = tokio::time::error::Elapsed;
type Timeout<R, T: Future<Output = R> + OptionalSend> = tokio::time::Timeout<T>;
type ThreadLocalRng = rand::rngs::ThreadRng;

#[inline]
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
Expand Down Expand Up @@ -121,4 +135,9 @@ impl AsyncRuntime for TokioRuntime {
fn abort<T: OptionalSend + 'static>(join_handle: &Self::JoinHandle<T>) {
join_handle.abort();
}

#[inline]
fn thread_rng() -> Self::ThreadLocalRng {
rand::thread_rng()
}
}
6 changes: 3 additions & 3 deletions openraft/src/config/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use std::time::Duration;

use anyerror::AnyError;
use clap::Parser;
use rand::thread_rng;
use rand::Rng;

use crate::config::error::ConfigError;
use crate::raft_state::LogStateReader;
use crate::AsyncRuntime;
use crate::LogIdOptionExt;
use crate::NodeId;

Expand Down Expand Up @@ -248,8 +248,8 @@ impl Default for Config {

impl Config {
/// Generate a new random election timeout within the configured min & max.
pub fn new_rand_election_timeout(&self) -> u64 {
thread_rng().gen_range(self.election_timeout_min..self.election_timeout_max)
pub fn new_rand_election_timeout<RT: AsyncRuntime>(&self) -> u64 {
RT::thread_rng().gen_range(self.election_timeout_min..self.election_timeout_max)
}

/// Get the timeout for sending and installing the last snapshot segment.
Expand Down
5 changes: 3 additions & 2 deletions openraft/src/engine/engine_config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::time::Duration;

use crate::engine::time_state;
use crate::AsyncRuntime;
use crate::Config;
use crate::NodeId;
use crate::SnapshotPolicy;
Expand Down Expand Up @@ -41,8 +42,8 @@ impl<NID: NodeId> Default for EngineConfig<NID> {
}

impl<NID: NodeId> EngineConfig<NID> {
pub(crate) fn new(id: NID, config: &Config) -> Self {
let election_timeout = Duration::from_millis(config.new_rand_election_timeout());
pub(crate) fn new<RT: AsyncRuntime>(id: NID, config: &Config) -> Self {
let election_timeout = Duration::from_millis(config.new_rand_election_timeout::<RT>());
Self {
id,
snapshot_policy: config.snapshot_policy.clone(),
Expand Down
2 changes: 1 addition & 1 deletion openraft/src/raft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ where C: RaftTypeConfig
cluster = display(&config.cluster_name)
);

let eng_config = EngineConfig::new(id, config.as_ref());
let eng_config = EngineConfig::new::<C::AsyncRuntime>(id, config.as_ref());

let state = {
let mut helper = StorageHelper::new(&mut log_store, &mut state_machine);
Expand Down

0 comments on commit 49e8adf

Please sign in to comment.