From 54cb6c52a442a5b4ad452964eab998e471471df5 Mon Sep 17 00:00:00 2001
From: Andrew Schran
Date: Wed, 12 Jun 2024 14:00:51 -0400
Subject: [PATCH] Rework randomness state machine & network loop to use watermarks (#18083)

## Description

We don't need to track individual pending rounds; everything can operate more
simply by tracking the highest completed round (as reported by the checkpoint
executor) and the highest requested round (as reported by the consensus
handler). This change resolves a race bug that occurs when checkpoint
execution is ahead of consensus.

## Test plan

Verified with unit tests. Manual test in synthetic environment.

---
 .../authority/authority_per_epoch_store.rs   |  32 +-
 crates/sui-core/src/epoch/randomness.rs      |  24 +-
 crates/sui-network/src/randomness/builder.rs |   5 +-
 crates/sui-network/src/randomness/mod.rs     | 281 +++++++++---------
 crates/sui-types/src/crypto.rs               |  14 +
 5 files changed, 166 insertions(+), 190 deletions(-)

diff --git a/crates/sui-core/src/authority/authority_per_epoch_store.rs b/crates/sui-core/src/authority/authority_per_epoch_store.rs
index 912007d6c25d8..35559eb42c709 100644
--- a/crates/sui-core/src/authority/authority_per_epoch_store.rs
+++ b/crates/sui-core/src/authority/authority_per_epoch_store.rs
@@ -511,8 +511,9 @@ pub struct AuthorityEpochTables {
     /// Records the final output of DKG after completion, including the public VSS key and
     /// any local private shares.
     pub(crate) dkg_output: DBMap>,
-    /// RandomnessRound numbers that are still pending generation.
-    pub(crate) randomness_rounds_pending: DBMap<RandomnessRound, ()>,
+    /// This table is no longer used (can be removed when DBMap supports removing tables)
+    #[allow(dead_code)]
+    randomness_rounds_pending: DBMap<RandomnessRound, ()>,
     /// Holds the value of the next RandomnessRound to be generated.
     pub(crate) randomness_next_round: DBMap<u64, RandomnessRound>,
     /// Holds the value of the highest completed RandomnessRound (as reported to RandomnessReporter).
@@ -666,31 +667,6 @@ impl AuthorityEpochTables {
         batch.write()?;
         Ok(())
     }
-
-    pub fn check_and_fix_consistency(&self) {
-        if let Some(randomness_highest_completed_round) = self
-            .randomness_highest_completed_round
-            .get(&crate::epoch::randomness::SINGLETON_KEY)
-            .expect("typed_store should not fail")
-        {
-            let old_randomness_rounds = self
-                .randomness_rounds_pending
-                .unbounded_iter()
-                .map(|(round, _)| round)
-                .take_while(|round| *round <= randomness_highest_completed_round)
-                .collect::<Vec<_>>();
-            // TODO: enable this debug_assert once race is fixed.
-            // debug_assert!(old_randomness_rounds.is_empty());
-            if !old_randomness_rounds.is_empty() {
-                error!("Found {} pending randomness rounds that are older than the highest completed round {randomness_highest_completed_round}. 
Removing them now.", old_randomness_rounds.len()); - }; - for round in old_randomness_rounds { - self.randomness_rounds_pending - .remove(&round) - .expect("typed_store should not fail"); - } - } - } } pub(crate) const MUTEX_TABLE_SIZE: usize = 1024; @@ -715,8 +691,6 @@ impl AuthorityPerEpochStore { let epoch_id = committee.epoch; let tables = AuthorityEpochTables::open(epoch_id, parent_path, db_options.clone()); - tables.check_and_fix_consistency(); - let end_of_publish = StakeAggregator::from_iter(committee.clone(), tables.end_of_publish.unbounded_iter()); let reconfig_state = tables diff --git a/crates/sui-core/src/epoch/randomness.rs b/crates/sui-core/src/epoch/randomness.rs index 37827bf18c2ad..decc2e0d723c9 100644 --- a/crates/sui-core/src/epoch/randomness.rs +++ b/crates/sui-core/src/epoch/randomness.rs @@ -278,13 +278,15 @@ impl RandomnessManager { "random beacon: starting from next_randomness_round={}", rm.next_randomness_round.0 ); - for result in tables.randomness_rounds_pending.safe_iter() { - let (round, _) = result.expect("typed_store should not fail"); + if highest_completed_round + 1 < rm.next_randomness_round { info!( - "random beacon: resuming generation for randomness round {}", - round.0 + "random beacon: resuming generation for randomness rounds from {} to {}", + highest_completed_round + 1, + rm.next_randomness_round - 1, ); - network_handle.send_partial_signatures(committee.epoch(), round); + for r in highest_completed_round.0 + 1..rm.next_randomness_round.0 { + network_handle.send_partial_signatures(committee.epoch(), RandomnessRound(r)); + } } Some(rm) @@ -589,10 +591,6 @@ impl RandomnessManager { .checked_add(1) .expect("RandomnessRound should not overflow"); - batch.insert_batch( - &tables.randomness_rounds_pending, - std::iter::once((randomness_round, ())), - )?; batch.insert_batch( &tables.randomness_next_round, std::iter::once((SINGLETON_KEY, self.next_randomness_round)), @@ -693,10 +691,6 @@ impl RandomnessReporter { .epoch_store .upgrade() .ok_or(SuiError::EpochEnded(self.epoch))?; - epoch_store - .tables()? - .randomness_rounds_pending - .remove(&round)?; let mut highest_completed_round = self.highest_completed_round.lock(); if round > *highest_completed_round { *highest_completed_round = round; @@ -704,9 +698,9 @@ impl RandomnessReporter { .tables()? 
.randomness_highest_completed_round .insert(&SINGLETON_KEY, &highest_completed_round)?; + self.network_handle + .complete_round(epoch_store.committee().epoch(), round); } - self.network_handle - .complete_round(epoch_store.committee().epoch(), round); Ok(()) } } diff --git a/crates/sui-network/src/randomness/builder.rs b/crates/sui-network/src/randomness/builder.rs index 90c0d24171028..2cc96f7733881 100644 --- a/crates/sui-network/src/randomness/builder.rs +++ b/crates/sui-network/src/randomness/builder.rs @@ -127,14 +127,13 @@ impl UnstartedRandomness { peer_share_ids: None, dkg_output: None, aggregation_threshold: 0, - pending_tasks: BTreeSet::new(), + highest_requested_round: BTreeMap::new(), send_tasks: BTreeMap::new(), round_request_time: BTreeMap::new(), future_epoch_partial_sigs: BTreeMap::new(), received_partial_sigs: BTreeMap::new(), completed_sigs: BTreeSet::new(), - completed_rounds: BTreeSet::new(), - recovered_last_completed_round: None, + highest_completed_round: BTreeMap::new(), }, handle, ) diff --git a/crates/sui-network/src/randomness/mod.rs b/crates/sui-network/src/randomness/mod.rs index bfe42ae948405..93c789a443ca3 100644 --- a/crates/sui-network/src/randomness/mod.rs +++ b/crates/sui-network/src/randomness/mod.rs @@ -127,7 +127,7 @@ enum RandomnessMessage { HashMap, dkg::Output, u16, // aggregation_threshold - Option, // recovered_last_completed_round + Option, // recovered_highest_completed_round ), SendPartialSignatures(EpochId, RandomnessRound), CompleteRound(EpochId, RandomnessRound), @@ -148,15 +148,13 @@ struct RandomnessEventLoop { peer_share_ids: Option>>, dkg_output: Option>, aggregation_threshold: u16, - pending_tasks: BTreeSet<(EpochId, RandomnessRound)>, - send_tasks: BTreeMap<(EpochId, RandomnessRound), tokio::task::JoinHandle<()>>, + highest_requested_round: BTreeMap, + send_tasks: BTreeMap>, round_request_time: BTreeMap<(EpochId, RandomnessRound), time::Instant>, future_epoch_partial_sigs: BTreeMap<(EpochId, RandomnessRound, PeerId), Vec>>, - received_partial_sigs: - BTreeMap<(EpochId, RandomnessRound, PeerId), Vec>, + received_partial_sigs: BTreeMap<(RandomnessRound, PeerId), Vec>, completed_sigs: BTreeSet<(EpochId, RandomnessRound)>, - completed_rounds: BTreeSet<(EpochId, RandomnessRound)>, - recovered_last_completed_round: Option, // reported by RandomnessManager on crash recovery + highest_completed_round: BTreeMap, } impl RandomnessEventLoop { @@ -187,14 +185,14 @@ impl RandomnessEventLoop { authority_info, dkg_output, aggregation_threshold, - recovered_last_completed_round, + recovered_highest_completed_round, ) => { if let Err(e) = self.update_epoch( epoch, authority_info, dkg_output, aggregation_threshold, - recovered_last_completed_round, + recovered_highest_completed_round, ) { error!("BUG: failed to update epoch in RandomnessEventLoop: {e:?}"); } @@ -216,7 +214,7 @@ impl RandomnessEventLoop { authority_info: HashMap, dkg_output: dkg::Output, aggregation_threshold: u16, - recovered_last_completed_round: Option, + recovered_highest_completed_round: Option, ) -> Result<()> { assert!(self.dkg_output.is_none() || new_epoch > self.epoch); @@ -243,25 +241,27 @@ impl RandomnessEventLoop { self.authority_info = Arc::new(authority_info); self.dkg_output = Some(dkg_output); self.aggregation_threshold = aggregation_threshold; - self.recovered_last_completed_round = recovered_last_completed_round; + if let Some(round) = recovered_highest_completed_round { + self.highest_completed_round + .entry(new_epoch) + .and_modify(|r| *r = std::cmp::max(*r, round)) + 
.or_insert(round); + } for (_, task) in std::mem::take(&mut self.send_tasks) { task.abort(); } self.metrics.set_epoch(new_epoch); // Throw away info from old epochs. + self.highest_requested_round = self.highest_requested_round.split_off(&new_epoch); self.round_request_time = self .round_request_time .split_off(&(new_epoch, RandomnessRound(0))); - self.received_partial_sigs = - self.received_partial_sigs - .split_off(&(new_epoch, RandomnessRound(0), PeerId([0; 32]))); + self.received_partial_sigs.clear(); self.completed_sigs = self .completed_sigs .split_off(&(new_epoch, RandomnessRound(0))); - self.completed_rounds = self - .completed_rounds - .split_off(&(new_epoch, RandomnessRound(0))); + self.highest_completed_round = self.highest_completed_round.split_off(&new_epoch); // Start any pending tasks for the new epoch. self.maybe_start_pending_tasks(); @@ -275,21 +275,9 @@ impl RandomnessEventLoop { // We can fully validate these now that we have current epoch DKG output. self.receive_partial_signatures(peer_id, epoch, round, sig_bytes); } - let mut aggregate_rounds = BTreeSet::new(); - for (epoch, round, _) in self.received_partial_sigs.keys() { - if *epoch < new_epoch { - error!("BUG: received partial sigs for old epoch still present after attempting to remove them"); - debug_assert!(false, "received partial sigs for old epoch still present after attempting to remove them"); - continue; - } - if *epoch > new_epoch { - break; - } - if !self.completed_sigs.contains(&(*epoch, *round)) { - aggregate_rounds.insert(*round); - } - } - for round in aggregate_rounds { + let rounds_to_aggregate: Vec<_> = + self.received_partial_sigs.keys().map(|(r, _)| *r).collect(); + for round in rounds_to_aggregate { self.maybe_aggregate_partial_signatures(new_epoch, round); } @@ -309,12 +297,19 @@ impl RandomnessEventLoop { ); return; } - if self.completed_rounds.contains(&(epoch, round)) { - info!("skipping sending partial sigs, we already have completed this round"); - return; + if epoch == self.epoch { + if let Some(highest_completed_round) = self.highest_completed_round.get(&epoch) { + if round <= *highest_completed_round { + info!("skipping sending partial sigs, we already have completed this round"); + return; + } + } } - self.pending_tasks.insert((epoch, round)); + self.highest_requested_round + .entry(epoch) + .and_modify(|r| *r = std::cmp::max(*r, round)) + .or_insert(round); self.round_request_time .insert((epoch, round), time::Instant::now()); self.maybe_start_pending_tasks(); @@ -323,24 +318,32 @@ impl RandomnessEventLoop { #[instrument(level = "debug", skip_all, fields(?epoch, ?round))] fn complete_round(&mut self, epoch: EpochId, round: RandomnessRound) { debug!("completing randomness round"); - self.pending_tasks.remove(&(epoch, round)); - self.round_request_time.remove(&(epoch, round)); - self.completed_rounds.insert((epoch, round)); + let new_highest_round = *self + .highest_completed_round + .entry(epoch) + .and_modify(|r| *r = std::cmp::max(*r, round)) + .or_insert(round); + if round != new_highest_round { + // This round completion came out of order, and we're already ahead. Nothing more + // to do in that case. + return; + } - // In case we first received the full sig from a checkpoint instead of aggregating it - // locally, update related data structures here. 
- self.completed_sigs.insert((epoch, round)); - self.remove_partial_sigs_in_range(( - Bound::Included((epoch, round, PeerId([0; 32]))), - Bound::Excluded((epoch, round + 1, PeerId([0; 32]))), - )); + self.round_request_time = self.round_request_time.split_off(&(epoch, round + 1)); - if let Some(task) = self.send_tasks.remove(&(epoch, round)) { - task.abort(); + if epoch == self.epoch { + self.remove_partial_sigs_in_range(( + Bound::Included((RandomnessRound(0), PeerId([0; 32]))), + Bound::Excluded((round + 1, PeerId([0; 32]))), + )); + for (_, task) in self.send_tasks.iter().take_while(|(r, _)| **r <= round) { + task.abort(); + } + self.send_tasks = self.send_tasks.split_off(&(round + 1)); self.maybe_start_pending_tasks(); - } else { - self.update_rounds_pending_metric(); } + + self.update_rounds_pending_metric(); } #[instrument(level = "debug", skip_all, fields(?peer_id, ?epoch, ?round))] @@ -370,6 +373,13 @@ impl RandomnessEventLoop { debug!("skipping received partial sigs, we already have completed this sig"); return; } + let highest_completed_round = self.highest_completed_round.get(&epoch).copied(); + if let Some(highest_completed_round) = &highest_completed_round { + if *highest_completed_round >= round { + debug!("skipping received partial sigs, we already have completed this round"); + return; + } + } // If sigs are for a future epoch, we can't fully verify them without DKG output. // Save them for later use. @@ -401,37 +411,24 @@ impl RandomnessEventLoop { ); return; } - let (last_completed_epoch, last_completed_round) = match self.completed_sigs.last() { - Some((last_completed_epoch, last_completed_round)) => { - (*last_completed_epoch, *last_completed_round) - } - // If we just changed epochs and haven't completed any sigs yet, or if we - // restarted mid-epoch, this will be used. - None => ( - self.epoch, - // We don't store completed sigs durably outside of checkpoints, so after a - // restart we use the last completed round instead. This is okay because - // incomplete rounds with previously-completed sigs will be re-opened - // by the RandomnessManager on restart, and we'll simply repeat the process. - self.recovered_last_completed_round - .unwrap_or(RandomnessRound(0)), - ), - }; - if epoch == last_completed_epoch - && round.0 - >= last_completed_round - .0 - .saturating_add(self.config.max_partial_sigs_rounds_ahead()) + + // Accept partial signatures up to `max_partial_sigs_rounds_ahead` past the round of the + // last completed signature, or the highest completed round, whichever is greater. + let last_completed_signature = self + .completed_sigs + .range(..&(epoch + 1, RandomnessRound(0))) + .next_back() + .map(|(e, r)| if *e == epoch { *r } else { RandomnessRound(0) }); + let last_completed_round = std::cmp::max(last_completed_signature, highest_completed_round) + .unwrap_or(RandomnessRound(0)); + if round.0 + >= last_completed_round + .0 + .saturating_add(self.config.max_partial_sigs_rounds_ahead()) { debug!( - "skipping received partial sigs, most recent round we completed was only {last_completed_round}", - ); - return; - } - if epoch > last_completed_epoch && round.0 >= self.config.max_partial_sigs_rounds_ahead() { - debug!( - "skipping received partial sigs, most recent epoch we completed was only {last_completed_epoch}", - ); + "skipping received partial sigs, most recent round we completed was only {last_completed_round}", + ); return; } @@ -465,7 +462,7 @@ impl RandomnessEventLoop { // We passed all the checks, save the partial sigs. 
debug!("recording received partial signatures"); self.received_partial_sigs - .insert((epoch, round, peer_id), partial_sigs); + .insert((round, peer_id), partial_sigs); self.maybe_aggregate_partial_signatures(epoch, round); } @@ -473,17 +470,19 @@ impl RandomnessEventLoop { #[instrument(level = "debug", skip_all, fields(?epoch, ?round))] fn maybe_aggregate_partial_signatures(&mut self, epoch: EpochId, round: RandomnessRound) { if self.completed_sigs.contains(&(epoch, round)) { - error!("BUG: called maybe_aggregate_partial_signatures for already-completed round"); - debug_assert!( - false, - "called maybe_aggregate_partial_signatures for already-completed round" - ); + info!("skipping aggregation for already-completed signature"); return; } - if !(self.send_tasks.contains_key(&(epoch, round)) - || self.pending_tasks.contains(&(epoch, round))) - { + if let Some(highest_completed_round) = self.highest_completed_round.get(&epoch) { + if round <= *highest_completed_round { + info!("skipping aggregation for already-completed round"); + return; + } + } + + let highest_requested_round = self.highest_requested_round.get(&epoch); + if highest_requested_round.is_none() || round > *highest_requested_round.unwrap() { // We have to wait here, because even if we have enough information from other nodes // to complete the signature, local shared object versions are not set until consensus // finishes processing the corresponding commit. This function will be called again @@ -492,6 +491,13 @@ impl RandomnessEventLoop { return; } + if epoch != self.epoch { + debug!( + "waiting to aggregate randomness partial signatures until DKG completes for epoch" + ); + return; + } + let vss_pk = { let Some(dkg_output) = &self.dkg_output else { debug!("called maybe_aggregate_partial_signatures before DKG completed"); @@ -501,8 +507,8 @@ impl RandomnessEventLoop { }; let sig_bounds = ( - Bound::Included((epoch, round, PeerId([0; 32]))), - Bound::Excluded((epoch, round + 1, PeerId([0; 32]))), + Bound::Included((round, PeerId([0; 32]))), + Bound::Excluded((round + 1, PeerId([0; 32]))), ); // If we have enough partial signatures, aggregate them. @@ -526,8 +532,8 @@ impl RandomnessEventLoop { // one-by-one to find which. // TODO: add test for individual sig verification. self.received_partial_sigs - .retain(|&(e, r, peer_id), partial_sigs| { - if epoch != e || round != r { + .retain(|&(r, peer_id), partial_sigs| { + if round != r { return true; } if ThresholdBls12381MinSig::partial_verify_batch( @@ -587,50 +593,48 @@ impl RandomnessEventLoop { let dkg_output = if let Some(dkg_output) = &self.dkg_output { dkg_output } else { - return; // can't start tasks until first DKG completes + return; // wait for DKG }; let shares = if let Some(shares) = &dkg_output.shares { shares } else { return; // can't participate in randomness generation without shares }; + let highest_requested_round = + if let Some(highest_requested_round) = self.highest_requested_round.get(&self.epoch) { + highest_requested_round + } else { + return; // no rounds to start + }; + // Begin from the next round after the most recent one we've started (or, if none are running, + // after the highest completed round in the epoch). 
+ let start_round = std::cmp::max( + if let Some(highest_completed_round) = self.highest_completed_round.get(&self.epoch) { + highest_completed_round.checked_add(1).unwrap() + } else { + RandomnessRound(0) + }, + self.send_tasks + .last_key_value() + .map(|(r, _)| r.checked_add(1).unwrap()) + .unwrap_or(RandomnessRound(0)), + ); - let mut last_handled_key = None; let mut rounds_to_aggregate = Vec::new(); - for (epoch, round) in &self.pending_tasks { - if epoch > &self.epoch { - break; // wait for DKG in new epoch - } + for round in start_round.0..=highest_requested_round.0 { + let round = RandomnessRound(round); if self.send_tasks.len() >= self.config.max_partial_sigs_concurrent_sends() { break; // limit concurrent tasks } - last_handled_key = Some((*epoch, *round)); - - if epoch < &self.epoch { - info!( - "skipping sending partial sigs for epoch {epoch} round {round}, we are already up to epoch {}", - self.epoch - ); - continue; - } - - if self.completed_rounds.contains(&(*epoch, *round)) { - info!( - "skipping sending partial sigs for epoch {epoch} round {round}, we already have completed this round", - ); - continue; - } - - self.send_tasks.entry((*epoch, *round)).or_insert_with(|| { + self.send_tasks.entry(round).or_insert_with(|| { let name = self.name; let network = self.network.clone(); let retry_interval = self.config.partial_signature_retry_interval(); let metrics = self.metrics.clone(); let authority_info = self.authority_info.clone(); - let epoch = *epoch; - let round = *round; + let epoch = self.epoch; let partial_sigs = ThresholdBls12381MinSig::partial_sign_batch( shares.iter(), &round.signature_message(), @@ -639,7 +643,7 @@ impl RandomnessEventLoop { // Record own partial sigs. if !self.completed_sigs.contains(&(epoch, round)) { self.received_partial_sigs - .insert((epoch, round, self.network.peer_id()), partial_sigs.clone()); + .insert((round, self.network.peer_id()), partial_sigs.clone()); rounds_to_aggregate.push((epoch, round)); } @@ -657,19 +661,6 @@ impl RandomnessEventLoop { }); } - if let Some(last_handled_key) = last_handled_key { - // Remove stuff from the pending_tasks map that we've handled. - let split_point = self - .pending_tasks - .range((Bound::Excluded(last_handled_key), Bound::Unbounded)) - .next() - .cloned(); - if let Some(key) = split_point { - self.pending_tasks = self.pending_tasks.split_off(&key); - } else { - self.pending_tasks.clear(); - } - } self.update_rounds_pending_metric(); // After starting a round, we have generated our own partial sigs. Check if that's @@ -683,8 +674,8 @@ impl RandomnessEventLoop { fn remove_partial_sigs_in_range( &mut self, range: ( - Bound<(u64, RandomnessRound, PeerId)>, - Bound<(u64, RandomnessRound, PeerId)>, + Bound<(RandomnessRound, PeerId)>, + Bound<(RandomnessRound, PeerId)>, ), ) { let keys_to_remove: Vec<_> = self @@ -764,21 +755,25 @@ impl RandomnessEventLoop { } fn update_rounds_pending_metric(&self) { - let num_rounds_pending = (self.pending_tasks.len() + self.send_tasks.len()) as i64; + let highest_requested_round = self + .highest_requested_round + .get(&self.epoch) + .map(|r| r.0) + .unwrap_or(0); + let highest_completed_round = self + .highest_completed_round + .get(&self.epoch) + .map(|r| r.0) + .unwrap_or(0); + let num_rounds_pending = + highest_requested_round.saturating_sub(highest_completed_round) as i64; let prev_value = self.metrics.num_rounds_pending().unwrap_or_default(); if num_rounds_pending / 100 > prev_value / 100 { warn!( // Recording multiples of 100 so tests can match on the log message. 
"RandomnessEventLoop randomness generation backlog: over {} rounds are pending (oldest is {:?})", (num_rounds_pending / 100) * 100, - match (self.pending_tasks.first(), self.send_tasks.first_key_value()) { - (Some(p), Some((s, _))) => { - std::cmp::min(p, s) - } - (Some(p), None) => p, - (None, Some((s, _))) => s, - (None, None) => &(0, RandomnessRound(0)), - }, + highest_completed_round+1, ); } self.metrics.set_num_rounds_pending(num_rounds_pending); diff --git a/crates/sui-types/src/crypto.rs b/crates/sui-types/src/crypto.rs index b4732bc05a80f..686c3b7760c91 100644 --- a/crates/sui-types/src/crypto.rs +++ b/crates/sui-types/src/crypto.rs @@ -1749,6 +1749,20 @@ impl std::ops::Add for RandomnessRound { } } +impl std::ops::Sub for RandomnessRound { + type Output = Self; + fn sub(self, other: Self) -> Self { + Self(self.0 - other.0) + } +} + +impl std::ops::Sub for RandomnessRound { + type Output = Self; + fn sub(self, other: u64) -> Self { + Self(self.0 - other) + } +} + impl RandomnessRound { pub fn new(round: u64) -> Self { Self(round)