diff --git a/proto/raft/raft.proto b/proto/raft/raft.proto index ffff597..1676562 100644 --- a/proto/raft/raft.proto +++ b/proto/raft/raft.proto @@ -38,8 +38,8 @@ message AppendEntriesArgs { message AppendEntriesReply { uint64 term = 1; bool success = 2; - uint64 conflict_index = 3; - uint64 conflict_term = 4; + uint64 last_log_index = 3; + uint64 last_log_term = 4; } message InstallSnapshotArgs { diff --git a/src/bin/radis/main.rs b/src/bin/radis/main.rs index d684c24..5c9eaed 100644 --- a/src/bin/radis/main.rs +++ b/src/bin/radis/main.rs @@ -2,6 +2,7 @@ use clap::Parser; use radis::conf::Config; use radis::raft::RaftService; use tokio; +use tokio::sync::mpsc; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] @@ -18,7 +19,8 @@ async fn main() -> Result<(), Box> { let args = Args::parse(); let cfg = Config::from_path(&args.conf)?; - RaftService::new(cfg).serve().await?; + let (commit_tx, _) = mpsc::channel(1); + RaftService::new(cfg, commit_tx).serve().await?; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 7299e74..e7a3529 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ pub mod raft { mod config; mod context; + mod log; mod service; pub mod state; diff --git a/src/raft/context.rs b/src/raft/context.rs index 1b866c1..798568a 100644 --- a/src/raft/context.rs +++ b/src/raft/context.rs @@ -1,12 +1,13 @@ use super::config::REQUEST_TIMEOUT; +use super::log::LogManager; use super::service::PeerClient; use crate::conf::Config; use crate::timer::{OneshotTimer, PeriodicTimer}; -use log::debug; +use log::{debug, info}; use std::sync::Arc; use std::time::Duration; -use tokio::sync::mpsc::Sender; -use tokio::sync::Mutex; +use tokio::sync::mpsc::{self, Sender}; +use tokio::sync::{Mutex, RwLock}; pub type PeerID = String; pub type Peer = usize; @@ -16,18 +17,29 @@ pub struct Context { id: String, peers: Vec>>, + log: LogManager, + peer_next_index: Vec>>, + peer_sync_index: Vec>>, + commit_ch: mpsc::Sender>>, + timeout: Arc, tick: Arc, } impl Context { - pub fn new(cfg: Config, timeout_event: Sender<()>, tick_event: Sender<()>) -> Self { + pub fn new( + cfg: Config, + commit_ch: mpsc::Sender>>, + timeout_event: Sender<()>, + tick_event: Sender<()>, + ) -> Self { let timeout = Duration::from_millis(REQUEST_TIMEOUT); let Config { id, listen_addr: _, peer_addrs, } = cfg; + let peers = peer_addrs.len(); Self { id, @@ -44,6 +56,12 @@ impl Context { Arc::new(Mutex::new(PeerClient::new(addr, timeout))) }) .collect(), + + log: LogManager::new(), + peer_next_index: (0..peers).map(|_| Arc::new(Mutex::new(0))).collect(), + peer_sync_index: (0..peers).map(|_| Arc::new(RwLock::new(0))).collect(), + commit_ch, + timeout: Arc::new(OneshotTimer::new(timeout_event)), tick: Arc::new(PeriodicTimer::new(tick_event)), } @@ -75,7 +93,7 @@ impl Context { } pub fn majority(&self) -> usize { - self.peers.len() / 2 + 1 + self.peers.len() / 2 } pub async fn reset_timeout(&self, timeout: Duration) { @@ -93,4 +111,37 @@ impl Context { pub async fn stop_tick(&self) { self.tick.stop().await; } + + pub fn log(&self) -> &LogManager { + &self.log + } + + pub fn log_mut(&mut self) -> &mut LogManager { + &mut self.log + } + + pub async fn commit_log(&mut self, index: LogIndex) { + self.log.commit(index, &self.commit_ch).await; + } + + pub fn peer_next_index(&self, peer: Peer) -> Arc> { + self.peer_next_index[peer].clone() + } + + pub async fn update_peer_index(&mut self, peer: Peer, index: LogIndex) { + *self.peer_sync_index[peer].write().await = index; + + let mut sync_indexes = vec![0; self.peers()]; + 
for (i, index) in self.peer_sync_index.iter().enumerate() { + sync_indexes[i] = *index.read().await; + } + info!("peer_sync_index: {:?}", sync_indexes); + self.commit_log(majority_index(sync_indexes)).await; + } +} + +fn majority_index(mut indexes: Vec<LogIndex>) -> LogIndex { + indexes.sort_unstable(); + let majority_index = indexes.len() / 2; // the leader itself also counts toward the majority + indexes[majority_index] } diff --git a/src/raft/log.rs b/src/raft/log.rs new file mode 100644 index 0000000..ff72052 --- /dev/null +++ b/src/raft/log.rs @@ -0,0 +1,130 @@ +use log::info; +use tokio::sync::mpsc; + +use super::context::LogIndex; +use super::state::Term; +use super::Log; +use std::ops::Deref; +use std::sync::Arc; + +#[derive(Debug)] +pub struct InnerLog { + term: Term, + data: Arc<Vec<u8>>, +} + +impl InnerLog { + pub fn new(term: Term, data: Vec<u8>) -> Self { + Self { + term, + data: Arc::new(data), + } + } +} + +pub struct LogManager { + commit_index: LogIndex, + snapshot_index: LogIndex, + logs: Vec<InnerLog>, + #[allow(dead_code)] + snapshot: Option<Arc<Vec<u8>>>, +} + +impl LogManager { + pub fn new() -> Self { + Self { + commit_index: 0, + snapshot_index: 0, + logs: vec![InnerLog::new(0, vec![])], + snapshot: None, + } + } + + pub fn append(&mut self, term: Term, log: Vec<u8>) -> LogIndex { + self.logs.push(InnerLog::new(term, log)); + self.snapshot_index + self.logs.len() as u64 - 1 // index of the entry just appended + } + + pub fn since(&self, start: usize) -> Vec<Log> { + let offset = start; // `start` is already an absolute log index + self.logs[start - self.snapshot_index as usize..] + .iter() + .enumerate() + .map(|(i, log)| Log { + index: (i + offset) as u64, + term: log.term, + command: log.data.deref().clone(), + }) + .collect() + } + + pub fn latest(&self) -> Option<(LogIndex, Term)> { + self.logs + .last() + .map(|log| (self.snapshot_index + self.logs.len() as u64 - 1, log.term)) + } + + pub fn term(&self, index: LogIndex) -> Option<Term> { + if index < self.snapshot_index { + return None; + } + let index = (index - self.snapshot_index) as usize; + if index >= self.logs.len() { + return None; + } + Some(self.logs[index].term) + } + + pub fn first_log_at_term(&self, term: Term) -> Option<LogIndex> { + self.logs + .iter() + .position(|log| log.term == term) + .map(|index| index as u64 + self.snapshot_index) + } + + pub fn delete_since(&mut self, index: LogIndex) -> usize { + if index < self.snapshot_index { + return 0; + } + let index = (index - self.snapshot_index) as usize; + let deleted = self.logs.drain(index..).count(); + info!(target: "raft::log", + deleted = deleted; + "delete logs[{}..]", index + ); + deleted + } + + pub fn commit_index(&self) -> LogIndex { + self.commit_index + } + + pub fn snapshot_index(&self) -> LogIndex { + self.snapshot_index + } + + pub async fn commit(&mut self, index: LogIndex, ch: &mpsc::Sender<Arc<Vec<u8>>>) { + if index <= self.commit_index { + return; + } + + info!(target: "raft::log", + "commit logs[{}..{}]", + self.commit_index, index + ); + // commit_index is the last index that has already been committed, + // so start sending from commit_index + 1 + let start = (self.commit_index - self.snapshot_index + 1) as usize; + let end = (index - self.snapshot_index + 1) as usize; + for log in if end >= self.logs.len() { + &self.logs[start..]
+ } else { + &self.logs[start..end] + } + .iter() + { + ch.send(log.data.clone()).await.unwrap(); + } + self.commit_index = index; + } +} diff --git a/src/raft/service.rs b/src/raft/service.rs index 9de1de1..104aec9 100644 --- a/src/raft/service.rs +++ b/src/raft/service.rs @@ -10,7 +10,7 @@ use anyhow::Result; use log::{info, trace}; use std::sync::Arc; use std::time::Duration; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::{mpsc, Mutex, RwLock}; use tonic::transport::{Channel, Endpoint, Server}; use tonic::{Request, Response, Status}; @@ -20,19 +20,22 @@ pub struct RaftService { listen_addr: String, context: Arc>, state: Arc>>>, + + close_ch: Arc>>>, } impl RaftService { - pub fn new(cfg: Config) -> Self { + pub fn new(cfg: Config, commit_ch: mpsc::Sender>>) -> Self { let Config { id, listen_addr, .. } = cfg.clone(); - let (context, state) = state::init(cfg); + let (context, state) = state::init(cfg, commit_ch); RaftService { id, listen_addr, context, state, + close_ch: Arc::new(Mutex::new(None)), } } @@ -44,27 +47,38 @@ impl RaftService { self.state.clone() } + pub async fn append_command(&self, cmd: Vec) -> Result<()> { + info!(target: "raft::service", id = self.id; "append command"); + let state = self.state.lock().await; + let new_state = state.on_command(self.context.clone(), cmd).await?; + if state::transition(state, new_state, self.context.clone()).await { + Err(anyhow::anyhow!("not leader")) + } else { + Ok(()) + } + } + pub async fn serve(&self) -> Result<()> { + let (close_tx, mut close_rx) = mpsc::channel::<()>(1); + *self.close_ch.lock().await = Some(close_tx); let addr = self.listen_addr.parse()?; info!(target: "raft::service", id = self.id; "raft gRPC server listening on {addr}"); Server::builder() .add_service(RaftServer::new(self.clone())) - .serve(addr) + .serve_with_shutdown(addr, async move { + close_rx.recv().await; + }) .await?; Ok(()) } - pub async fn serve_with_shutdown>( - &self, - f: F, - ) -> Result<()> { - let addr = self.listen_addr.parse()?; - info!(target: "raft::service", id = self.id; "raft gRPC server listening on {addr}"); - Server::builder() - .add_service(RaftServer::new(self.clone())) - .serve_with_shutdown(addr, f) - .await?; - Ok(()) + pub async fn close(&self) { + if let Some(ch) = self.close_ch.lock().await.take() { + ch.send(()).await.unwrap(); + } + let ctx = self.context.read().await; + ctx.cancel_timeout().await; + ctx.stop_tick().await; } } diff --git a/src/raft/state.rs b/src/raft/state.rs index 630f34f..b02da27 100644 --- a/src/raft/state.rs +++ b/src/raft/state.rs @@ -5,7 +5,7 @@ use super::{ RequestVoteArgs, RequestVoteReply, }; use crate::conf::Config; -use log::{debug, info}; +use log::{debug, error, info}; use serde::ser::SerializeMap; use serde::{Serialize, Serializer}; use std::fmt::Debug; @@ -33,11 +33,17 @@ pub trait State: Sync + Send + Debug { fn term(&self) -> Term; fn role(&self) -> Role; fn following(&self) -> Option; - async fn setup_timer(&self, ctx: RaftContext); + async fn setup(&self, ctx: RaftContext); /// Call with holding the lock of state async fn on_timeout(&self, ctx: RaftContext) -> Option>>; /// Call without holding the lock of state async fn on_tick(&self, ctx: RaftContext) -> Option>>; + /// Append new command + async fn on_command( + &self, + ctx: RaftContext, + cmd: Vec, + ) -> anyhow::Result>>>; async fn request_vote_logic( &self, @@ -113,12 +119,16 @@ pub trait State: Sync + Send + Debug { reason = "invalid leader term"; "reject rpc request" ); + let (last_log_index, last_log_term) = match 
ctx.read().await.log().latest() { + Some(log) => log, + None => (0, 0), + }; return ( AppendEntriesReply { term: self.term(), success: false, - conflict_index: 0, - conflict_term: 0, + last_log_index, + last_log_term, }, None, ); @@ -187,10 +197,15 @@ impl Serialize for Box { } } -pub fn init(cfg: Config) -> (Arc>, Arc>>>) { +pub fn init( + cfg: Config, + commit_ch: mpsc::Sender>>, +) -> (Arc>, Arc>>>) { let (timeout_tx, timeout_rx) = mpsc::channel(1); let (tick_tx, tick_rx) = mpsc::channel(1); - let context = Arc::new(RwLock::new(Context::new(cfg, timeout_tx, tick_tx))); + let context = Arc::new(RwLock::new(Context::new( + cfg, commit_ch, timeout_tx, tick_tx, + ))); let init_state = FollowerState::new(0, None); info!(target: "raft::state", @@ -206,16 +221,43 @@ pub async fn transition<'a>( mut state: MutexGuard<'a, Arc>>, new_state: Option>>, ctx: Arc>, -) { +) -> bool { if let Some(new_state) = new_state { + if new_state.term() < state.term() { + // Forbid lower term state transition + error!(target: "raft::state", + old_state:serde = (&*state as &Box), + new_state:serde = (&new_state as &Box); + "invalid state transition" + ); + return false; + } + match (state.role(), new_state.role()) { + (Role::Follower, Role::Candidate) => {} + (Role::Candidate, Role::Leader) => {} + (Role::Follower, Role::Follower) => {} + (Role::Candidate, Role::Follower) => {} + (Role::Leader, Role::Follower) => {} + _ => { + // Forbid invalid state transition + error!(target: "raft::state", + old_state:serde = (&*state as &Box), + new_state:serde = (&new_state as &Box); + "invalid state transition" + ); + return false; + } + } info!(target: "raft::state", old_state:serde = (&*state as &Box), new_state:serde = (&new_state as &Box); "state transition occurred" ); *state = new_state; - state.setup_timer(ctx).await; + state.setup(ctx).await; + return true; } + false } fn handle_timer( @@ -226,7 +268,7 @@ fn handle_timer( ) { tokio::spawn(async move { ctx.read().await.init_timer().await; - state.lock().await.setup_timer(ctx.clone()).await; + state.lock().await.setup(ctx.clone()).await; loop { tokio::select! 
{ _ = timeout_rx.recv() => { @@ -267,7 +309,8 @@ fn handle_timer( #[tokio::test] async fn reject_lower_term() { - let (ctx, _) = init(Config::builder().peers(1).build().pop().unwrap()); + let (commit_tx, _) = mpsc::channel(1); + let (ctx, _) = init(Config::builder().peers(1).build().pop().unwrap(), commit_tx); let state = FollowerState::new(2, None); assert_eq!(2, state.term()); @@ -311,8 +354,8 @@ async fn reject_lower_term() { AppendEntriesReply { term: 2, success: false, - conflict_index: 0, - conflict_term: 0, + last_log_index: 0, + last_log_term: 0, }, reply ); diff --git a/src/raft/state/candidate.rs b/src/raft/state/candidate.rs index c349775..04b3634 100644 --- a/src/raft/state/candidate.rs +++ b/src/raft/state/candidate.rs @@ -38,7 +38,7 @@ impl State for CandidateState { None } - async fn setup_timer(&self, ctx: RaftContext) { + async fn setup(&self, ctx: RaftContext) { let timeout = config::candidate_timeout(); let tick = Duration::from_millis(config::REQUEST_VOTE_INTERVAL as u64); let ctx = ctx.read().await; @@ -52,6 +52,14 @@ impl State for CandidateState { ); } + async fn on_command( + &self, + _ctx: RaftContext, + _cmd: Vec, + ) -> anyhow::Result>>> { + Err(anyhow::anyhow!("not leader")) + } + async fn request_vote_logic( &self, _ctx: RaftContext, @@ -68,7 +76,7 @@ impl State for CandidateState { async fn append_entries_logic( &self, - _ctx: RaftContext, + ctx: RaftContext, args: AppendEntriesArgs, ) -> (AppendEntriesReply, Option>>) { info!(target: "raft::state", @@ -76,14 +84,10 @@ impl State for CandidateState { leader = args.leader_id; "other peer won current election, revert to follower" ); + let new_state = FollowerState::new(args.term.clone(), Some(args.leader_id.clone())); ( - AppendEntriesReply { - term: self.term, - success: true, - conflict_term: 0, - conflict_index: 0, - }, - Some(FollowerState::new(args.term, Some(args.leader_id))), + new_state.handle_append_entries(ctx, args).await.0, + Some(new_state), ) } @@ -133,6 +137,7 @@ impl State for CandidateState { let peer_url = peer_cli.lock().await.url(); match resp { Ok(resp) => { + // TODO: prefetch peer next index debug!(target: "raft::rpc", term = resp.term, peer = peer_url, diff --git a/src/raft/state/follower.rs b/src/raft/state/follower.rs index 625c95b..1948cec 100644 --- a/src/raft/state/follower.rs +++ b/src/raft/state/follower.rs @@ -21,18 +21,100 @@ impl FollowerState { } /// Check if the candidate's log is more up-to-date - fn is_request_valid(&self, _ctx: RaftContext, _args: &RequestVoteArgs) -> bool { - // TODO: implement this - true + async fn is_request_valid(&self, ctx: RaftContext, args: &RequestVoteArgs) -> bool { + let ctx = ctx.read().await; + let log = ctx.log(); + let (latest_index, latest_term) = match log.latest() { + Some(log) => log, + None => return true, + }; + + // From 5.4.1: + // If the logs have last entries with different terms, then + // the log with the later term is more up-to-date. + if args.last_log_term < latest_term { + return false; + } + if args.last_log_term > latest_term { + return true; + } + + // From 5.4.1: + // If the logs end with the same term, then + // whichever log is longer is more up-to-date. 
+ return args.last_log_index >= latest_index; } /// Handle entries from leader, return if any conflict appears - fn handle_entries( + async fn handle_entries( &self, - _ctx: RaftContext, - _args: AppendEntriesArgs, + ctx: RaftContext, + args: AppendEntriesArgs, ) -> Option<(Term, LogIndex)> { - // TODO: append logs and commit + let mut ctx = ctx.write().await; + let log = ctx.log_mut(); + let (latest_index, latest_term) = match log.latest() { + Some(log) => log, + None => (0, 0), + }; + + // Reply false if log doesn’t contain an entry at prevLogIndex + // whose term matches prevLogTerm (§5.3) + if args.prev_log_index > latest_index { + info!(target: "raft::log", + state:serde = self, + prev_index = args.prev_log_index, + latest_index = latest_index; + "reject append logs: prev index is larger than last log index" + ); + return Some((latest_term, latest_index)); + } + + let prev_term = match log.term(args.prev_log_index) { + Some(term) => term, + None => { + info!(target: "raft::log", + state:serde = self, + prev_index = args.prev_log_index; + "reject append logs: prev index is already trimmed" + ); + return Some((latest_term, latest_index)); + } + }; + + // If an existing entry conflicts with a new one (same index but different terms), + // delete the existing entry and all that follow it (§5.3) + if prev_term != args.prev_log_term { + info!(target: "raft::log", + state:serde = self, + prev_term = prev_term, + prev_index = args.prev_log_index; + "reject append logs: prev term conflicts with the local term at prev index" + ); + // Fall back over the whole conflicting term at a time + let term_start = log.first_log_at_term(prev_term).unwrap(); + log.delete_since(term_start); + return Some((prev_term, args.prev_log_index)); + } + + log.delete_since(args.prev_log_index + 1); + + // Append any new entries not already in the log + let entries = args.entries.len(); + if entries > 0 { + args.entries.into_iter().for_each(|entry| { + log.append(entry.term, entry.command); + }); + info!(target: "raft::log", + state:serde = self, + prev_index = args.prev_log_index, + prev_term = args.prev_log_term, + entries = entries; + "append new log [{}..{}]", + args.prev_log_index + 1, + args.prev_log_index + 1 + entries as u64 + ); + } None } } @@ -51,7 +133,7 @@ impl State for FollowerState { self.follow.clone() } - async fn setup_timer(&self, ctx: RaftContext) { + async fn setup(&self, ctx: RaftContext) { let timeout = config::follower_timeout(); let ctx = ctx.read().await; ctx.reset_timeout(timeout).await; @@ -64,6 +146,14 @@ impl State for FollowerState { ); } + async fn on_command( + &self, + _ctx: RaftContext, + _cmd: Vec<u8>, + ) -> anyhow::Result<Option<Arc<Box<dyn State>>>> { + Err(anyhow::anyhow!("not leader")) + } + async fn request_vote_logic( &self, ctx: RaftContext, @@ -80,7 +170,7 @@ impl State for FollowerState { } None => { // Not voted for any candidate yet - if self.is_request_valid(ctx.clone(), &args) { + if self.is_request_valid(ctx.clone(), &args).await { // Vote for this candidate ( true, @@ -120,7 +210,6 @@ impl State for FollowerState { let (success, new_state) = match &self.follow { Some(l) if *l == args.leader_id => { // Following this leader - // TODO: append entries to context debug!(target: "raft::rpc", state:serde = self, term = args.term, @@ -150,28 +239,39 @@ impl State for FollowerState { timeout:serde = timeout; "reset timeout timer" ); - if let Some((conflict_term, conflict_index)) = self.handle_entries(ctx, args) { + let commit_index = args.leader_commit; + if let Some((conflict_term, conflict_index)) =
self.handle_entries(ctx.clone(), args).await + { AppendEntriesReply { term: self.term, success: false, - conflict_term, - conflict_index, + last_log_index: conflict_index, + last_log_term: conflict_term, } } else { + let (last_log_index, last_log_term) = match ctx.read().await.log().latest() { + Some(log) => log, + None => (0, 0), + }; + ctx.write().await.commit_log(commit_index.min(last_log_index)).await; // never advance the commit index past the local log (§5.3) AppendEntriesReply { term: self.term, success: true, - conflict_term: 0, - conflict_index: 0, + last_log_index, + last_log_term, } } } else { - // TODO: set latest log' index and term + let (last_log_index, last_log_term) = match ctx.read().await.log().latest() { + Some(log) => log, + None => (0, 0), + }; AppendEntriesReply { term: self.term, success: false, - conflict_term: 0, - conflict_index: 0, + last_log_index, + last_log_term, } }; (reply, new_state) @@ -200,7 +300,11 @@ impl State for FollowerState { #[tokio::test] async fn vote_request() { - let (ctx, state) = super::init(super::Config::builder().peers(1).build().pop().unwrap()); + let (commit_tx, _) = tokio::sync::mpsc::channel(1); + let (ctx, state) = super::init( + super::Config::builder().peers(1).build().pop().unwrap(), + commit_tx, + ); let state = state.lock().await; assert_eq!(state.term(), 0); @@ -233,7 +337,11 @@ async fn vote_request() { #[tokio::test] async fn append_entries() { - let (ctx, state) = super::init(super::Config::builder().peers(1).build().pop().unwrap()); + let (commit_tx, _) = tokio::sync::mpsc::channel(1); + let (ctx, state) = super::init( + super::Config::builder().peers(1).build().pop().unwrap(), + commit_tx, + ); let state = state.lock().await; assert_eq!(state.term(), 0); @@ -257,8 +365,8 @@ async fn append_entries() { AppendEntriesReply { term: 2, success: true, - conflict_index: 0, - conflict_term: 0, + last_log_index: 0, + last_log_term: 0, }, reply ); diff --git a/src/raft/state/leader.rs b/src/raft/state/leader.rs index 990083b..9765649 100644 --- a/src/raft/state/leader.rs +++ b/src/raft/state/leader.rs @@ -20,6 +20,125 @@ impl LeaderState { pub fn new(term: Term) -> Arc<Box<dyn State>> { Arc::new(Box::new(Self { term })) } + + async fn sync_peers(&self, ctx: RaftContext) -> Option<Arc<Box<dyn State>>> { + let peers = ctx.read().await.peers(); + let higher_term = Arc::new(Mutex::new(self.term)); + let notify = Arc::new(Notify::new()); + let mut requests = Vec::with_capacity(peers); + + for peer in 0..peers { + let ctx = ctx.clone(); + let (peer_cli, peer_next_index, mut args) = { + let ctx = ctx.read().await; + let args = AppendEntriesArgs { + term: self.term, + leader_id: ctx.me().clone(), + prev_log_index: 0, + prev_log_term: 0, + entries: vec![], + leader_commit: ctx.log().commit_index(), + }; + + (ctx.get_peer(peer), ctx.peer_next_index(peer), args) + }; + let notify = notify.clone(); + let higher_term = higher_term.clone(); + + requests.push(tokio::spawn(async move { + // Hold the peer_next_index lock until the end of the request, + // so the next rpc can send entries accumulated during this one + let mut peer_next_index = peer_next_index.lock().await; + { + let ctx = ctx.read().await; + let log = ctx.log(); + let prev_index = *peer_next_index - 1; + match log.term(prev_index) { + Some(term) => { + args.prev_log_term = term; + args.prev_log_index = prev_index; + } + None => { + args.prev_log_term = args.term; + args.prev_log_index = log.snapshot_index(); + } + } + args.entries = log.since(*peer_next_index as usize); + } + let resp = peer_cli.lock().await.append_entries(args).await; + let peer_url = peer_cli.lock().await.url(); + match resp {
Ok(resp) => { + debug!(target: "raft::rpc", + term = resp.term, + peer = peer_url; + "call peer rpc AppendEntries" + ); + if !resp.success { + let mut higher_term = higher_term.lock().await; + if resp.term > *higher_term { + *higher_term = resp.term; + notify.notify_one(); + return; + } else { + // There must be a conflict + error!(target: "raft::rpc", + term = resp.term, + peer = peer_url, + conflict_index = resp.last_log_index, + conflict_term = resp.last_log_term; + "meet peer log conflict" + ); + } + } else { + // Update the peer's sync index only after a successful replication + info!(target: "raft::rpc", + term = resp.term, + peer = peer_url, + last_log_index = resp.last_log_index, + last_log_term = resp.last_log_term; + "sync peer log" + ); + ctx.write() + .await + .update_peer_index(peer, resp.last_log_index) + .await; + } + *peer_next_index = resp.last_log_index + 1; + } + Err(e) => { + error!(target: "raft::rpc", + error:err = e, + peer = peer_url; + "peer rpc AppendEntries error" + ); + } + } + })); + } + + let finish_notify = notify.clone(); + tokio::spawn(async move { + future::join_all(requests).await; + + // Wake up the main task once all requests are done + finish_notify.notify_one(); + }); + + notify.notified().await; + + let higher_term = *higher_term.lock().await; + if higher_term > self.term { + info!(target: "raft::rpc", + term = self.term, + new_term = higher_term; + "meet higher term, revert to follower" + ); + Some(FollowerState::new(higher_term, None)) + } else { + None + } + } } #[tonic::async_trait] @@ -36,9 +155,13 @@ impl State for LeaderState { None } - async fn setup_timer(&self, ctx: RaftContext) { + async fn setup(&self, ctx: RaftContext) { let tick = Duration::from_millis(config::HEARTBEAT_INTERVAL as u64); let ctx = ctx.read().await; + let init_next_index = ctx.log().latest().unwrap().0 + 1; + for peer in 0..ctx.peers() { + *ctx.peer_next_index(peer).lock().await = init_next_index; + } ctx.cancel_timeout().await; ctx.reset_tick(tick).await; debug!(target: "raft::state", @@ -49,6 +172,21 @@ impl State for LeaderState { ); } + async fn on_command( + &self, + ctx: RaftContext, + cmd: Vec<u8>, + ) -> anyhow::Result<Option<Arc<Box<dyn State>>>> { + let index = ctx.write().await.log_mut().append(self.term(), cmd); + info!(target: "raft::state", + state:serde = self, + index = index; + "command appended, waiting for sync to peers" + ); + // Step down if a higher term was discovered while syncing + Ok(self.sync_peers(ctx).await) + } + async fn request_vote_logic( &self, _ctx: RaftContext, @@ -78,78 +216,6 @@ impl State for LeaderState { } async fn on_tick(&self, ctx: RaftContext) -> Option<Arc<Box<dyn State>>> { - let (peers, me) = { - let ctx = ctx.read().await; - (ctx.peers(), ctx.me().clone()) - }; - let higher_term = Arc::new(Mutex::new(self.term)); - let notify = Arc::new(Notify::new()); - let mut requests = Vec::with_capacity(peers); - let args = AppendEntriesArgs { - term: self.term, - leader_id: me, - prev_log_index: 0, - prev_log_term: 0, - entries: vec![], - leader_commit: 0, - }; - - for peer in 0..peers { - let peer_cli = ctx.read().await.get_peer(peer); - let notify = notify.clone(); - let args = args.clone(); - let higher_term = higher_term.clone(); - - requests.push(tokio::spawn(async move { - let resp = peer_cli.lock().await.append_entries(args).await; - let peer_url = peer_cli.lock().await.url(); - match resp { - Ok(resp) => { - if !resp.success { - let mut higher_term = higher_term.lock().await; - if resp.term > *higher_term { - *higher_term = resp.term; - notify.notify_one(); - return; - } - } - debug!(target: "raft::rpc", - term = resp.term, - peer = peer_url; - "send heartbeat to peer" - ); - } -
Err(e) => { - error!(target: "raft::rpc", - error:err = e, - peer = peer_url; - "call peer rpc AppendEntries(hb) error" - ); - } - } - })); - } - - let finish_notify = notify.clone(); - tokio::spawn(async move { - future::join_all(requests).await; - - // Wake up the main task if all requests are done - finish_notify.notify_one(); - }); - - notify.notified().await; - - let higher_term = *higher_term.lock().await; - if higher_term > self.term { - info!(target: "raft::rpc", - term = self.term, - new_term = higher_term; - "meet higher term, revert to follower" - ); - Some(FollowerState::new(higher_term, None)) - } else { - None - } + self.sync_peers(ctx).await } } diff --git a/tests/raft_test.rs b/tests/raft_test.rs index bc33de9..3110b80 100644 --- a/tests/raft_test.rs +++ b/tests/raft_test.rs @@ -1,91 +1,195 @@ use radis::conf::Config; use radis::raft::state; use radis::raft::RaftService; +use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc; use tokio::time::sleep; -struct ServiceHandler { - service: RaftService, - close_ch: mpsc::Sender<()>, +#[tokio::test] +async fn leader_election() { + let mut ctl = Controller::new(3, 50000); + ctl.serve_all().await; + + // Wait for establishing agreement on one leader + sleep(Duration::from_millis(1000)).await; + + let (followers, candidates, leader_cnt) = ctl.count_roles().await; + // There should be only one leader, and all others are followers + assert_eq!(followers, 2); + assert_eq!(candidates, 0); + assert_eq!(leader_cnt, 1); + + ctl.close_all().await; + // Wait for all connections to finish + sleep(Duration::from_millis(500)).await; } -impl ServiceHandler { - #[allow(dead_code)] - async fn serve(&self) { - self.service.serve().await.unwrap(); - } +#[tokio::test] +async fn fail_over() { + let mut ctl = Controller::new(3, 50003); + ctl.serve_all().await; - async fn close(&self) { - self.close_ch.send(()).await.unwrap(); - } + // Wait for establishing agreement on one leader + sleep(Duration::from_millis(1000)).await; + + let (followers, candidates, leader_cnt) = ctl.count_roles().await; + // There should be only one leader, and all others are followers + assert_eq!(followers, 2); + assert_eq!(candidates, 0); + assert_eq!(leader_cnt, 1); + + // Leader offline + let old_leader = ctl.leader().await.unwrap(); + let old_term = ctl.term(old_leader).await; + ctl.close(old_leader).await; + + // Wait for one follower timeout and start new election + sleep(Duration::from_millis(1000)).await; + + let (followers, candidates, leader_cnt) = ctl.count_roles().await; + // There should be a new leader got elected + assert_eq!(followers, 1); + assert_eq!(candidates, 0); + assert_eq!(leader_cnt, 2); // New leader plus old leader + + // Old leader back online + ctl.serve(old_leader).await; + ctl.setup_timer(old_leader).await; + + // Wait for re-establishing agreement on new leader + sleep(Duration::from_millis(1000)).await; + + let new_leader = ctl.leader().await.unwrap(); + let new_term = ctl.term(new_leader).await; + // The old leader should be replaced by the new leader with a higher term + assert_ne!(old_leader, new_leader); + assert!(new_term > old_term); + + ctl.close_all().await; + // Wait for all connections to finish + sleep(Duration::from_millis(500)).await; +} + +#[tokio::test] +async fn basic_commit() { + let peers = 3; + let mut ctl = Controller::new(peers, 50006); + ctl.serve_all().await; + + // Wait for establishing agreement on one leader + sleep(Duration::from_millis(1000)).await; + let leader = ctl.leader().await.unwrap(); - async fn 
role(&self) -> state::Role { - self.service.state().lock().await.role() + let data = b"hello, raft!".to_vec(); + + ctl.append_command(leader, data.clone()).await; + let recv_data = ctl.read_commit(leader).await.unwrap(); + assert!(recv_data.as_slice() == data.as_slice()); + + // Wait for commit sync to peers + sleep(Duration::from_millis(500)).await; + for idx in (0..peers).filter(|idx| *idx != (leader as i32)) { + let recv_data = ctl.read_commit(idx as usize).await.unwrap(); + assert!(recv_data.as_slice() == data.as_slice()); } } -async fn start(peers: i32) -> Vec { - let services: Vec = Config::builder() - .peers(peers) - .build() - .iter() - .map(|cfg| RaftService::new(cfg.clone())) - .collect(); - - let mut handlers = Vec::new(); - for service in services { - let (close_tx, mut close_rx) = mpsc::channel(1); - let srv = service.clone(); +//////////////////////////////////////////////////////////////////////////////// + +struct Controller { + services: Vec, + commit_rxs: Vec>>>, +} + +impl Controller { + fn new(peers: i32, port_base: u16) -> Self { + let mut services = Vec::with_capacity(peers as usize); + let mut commit_rxs = Vec::with_capacity(peers as usize); + for cfg in Config::builder() + .peers(peers) + .base_port(port_base) + .build() + .into_iter() + { + let (commit_tx, commit_rx) = mpsc::channel(1); + services.push(RaftService::new(cfg, commit_tx)); + commit_rxs.push(commit_rx); + } + Controller { + services, + commit_rxs, + } + } + + async fn serve(&mut self, idx: usize) { + let srv = self.services[idx].clone(); tokio::spawn(async move { - srv.serve_with_shutdown(async move { - close_rx.recv().await; - }) - .await - .unwrap(); + srv.serve().await.unwrap(); }); - handlers.push(ServiceHandler { - service, - close_ch: close_tx, - }) } - handlers -} + async fn serve_all(&mut self) { + for i in 0..self.services.len() { + self.serve(i).await; + } + } + + async fn close(&self, idx: usize) { + self.services[idx].close().await; + } -async fn count_roles(services: &Vec) -> (i32, i32, i32) { - let mut leader_cnt = 0; - let mut followers = 0; - let mut candidates = 0; - for service in services { - match service.role().await { - state::Role::Leader => leader_cnt += 1, - state::Role::Follower => followers += 1, - state::Role::Candidate => candidates += 1, + async fn close_all(&self) { + for i in 0..self.services.len() { + self.close(i).await; } } - (leader_cnt, followers, candidates) -} -//////////////////////////////////////////////////////////////////////////////// + async fn setup_timer(&self, idx: usize) { + let srv = self.services[idx].clone(); + let state = srv.state(); + let ctx = srv.context(); + let state = state.lock().await; + state.on_timeout(ctx.clone()).await; + state.on_tick(ctx.clone()).await; + } -#[tokio::test] -async fn leader_election() { - let handlers = start(3).await; + async fn term(&self, idx: usize) -> u64 { + self.services[idx].state().lock().await.term() + } - // Wait for establishing agreement on one leader - sleep(Duration::from_secs(1)).await; + async fn role(&self, idx: usize) -> state::Role { + self.services[idx].state().lock().await.role() + } - let (leader_cnt, followers, candidates) = count_roles(&handlers).await; + async fn count_roles(&self) -> (i32, i32, i32) { + let mut followers = 0; + let mut candidates = 0; + let mut leader_cnt = 0; + for i in 0..self.services.len() { + match self.role(i).await { + state::Role::Follower => followers += 1, + state::Role::Candidate => candidates += 1, + state::Role::Leader => leader_cnt += 1, + } + } + (followers, 
candidates, leader_cnt) + } - assert_eq!(leader_cnt, 1); - assert_eq!(followers, 2); - assert_eq!(candidates, 0); + async fn leader(&self) -> Option<usize> { + for i in 0..self.services.len() { + if self.role(i).await == state::Role::Leader { + return Some(i); + } + } + None + } - for handler in handlers { - handler.close().await; + async fn append_command(&self, idx: usize, cmd: Vec<u8>) { + self.services[idx].append_command(cmd).await.unwrap(); } - // Wait for all connections to finish - sleep(Duration::from_secs(1)).await; + async fn read_commit(&mut self, idx: usize) -> Option<Arc<Vec<u8>>> { + self.commit_rxs[idx].recv().await + } }