Skip to content

Commit

Permalink
Tree states optimization using EpochCache (#4429)
Browse files Browse the repository at this point in the history
* Relocate epoch cache to BeaconState

* Optimize per block processing by pulling previous epoch & current epoch calculation up.

* Revert `get_cow` change (no performance improvement)

* Initialize `EpochCache` in epoch processing and load it from state when getting base rewards.

* Initialize `EpochCache` at start of block processing if required.

* Initialize `EpochCache` in `transition_blocks` if `exclude_cache_builds` is enabled

* Fix epoch cache initialization logic

* Remove FIXME comment.

* Cache previous & current epochs in `consensus_context.rs`.

* Move `get_base_rewards` from `ConsensusContext` to `BeaconState`.

* Update Milhouse version
  • Loading branch information
jimmygchen committed Jun 30, 2023
1 parent 160bbde commit 2df714e
Show file tree
Hide file tree
Showing 19 changed files with 239 additions and 198 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions beacon_node/beacon_chain/src/beacon_block_reward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use state_processing::{
per_block_processing::{
altair::sync_committee::compute_sync_aggregate_rewards, get_slashable_indices,
},
ConsensusContext,
};
use store::{
consts::altair::{PARTICIPATION_FLAG_WEIGHTS, PROPOSER_WEIGHT, WEIGHT_DENOMINATOR},
Expand Down Expand Up @@ -177,8 +176,6 @@ impl<T: BeaconChainTypes> BeaconChain<T> {
block: BeaconBlockRef<'_, T::EthSpec, Payload>,
state: &mut BeaconState<T::EthSpec>,
) -> Result<BeaconBlockSubRewardValue, BeaconChainError> {
let mut ctxt = ConsensusContext::new(block.slot());

let mut total_proposer_reward = 0;

let proposer_reward_denominator = WEIGHT_DENOMINATOR
Expand All @@ -202,8 +199,13 @@ impl<T: BeaconChainTypes> BeaconChain<T> {
for index in attesting_indices {
let index = index as usize;
for (flag_index, &weight) in PARTICIPATION_FLAG_WEIGHTS.iter().enumerate() {
let epoch_participation =
state.get_epoch_participation_mut(data.target.epoch)?;
let previous_epoch = state.previous_epoch();
let current_epoch = state.current_epoch();
let epoch_participation = state.get_epoch_participation_mut(
data.target.epoch,
previous_epoch,
current_epoch,
)?;
let validator_participation = epoch_participation
.get_mut(index)
.ok_or(BeaconStateError::ParticipationOutOfBounds(index))?;
Expand All @@ -213,7 +215,8 @@ impl<T: BeaconChainTypes> BeaconChain<T> {
{
validator_participation.add_flag(flag_index)?;
proposer_reward_numerator.safe_add_assign(
ctxt.get_base_reward(state, index, &self.spec)
state
.get_base_reward(index)
.map_err(|_| BeaconChainError::BlockRewardAttestationError)?
.safe_mul(weight)?,
)?;
Expand Down
39 changes: 9 additions & 30 deletions consensus/state_processing/src/consensus_context.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::common::get_indexed_attestation;
use crate::per_block_processing::errors::{AttestationInvalid, BlockOperationError};
use crate::{EpochCache, EpochCacheError};
use std::borrow::Cow;
use crate::EpochCacheError;
use std::collections::{hash_map::Entry, HashMap};
use std::marker::PhantomData;
use tree_hash::TreeHash;
Expand All @@ -14,12 +13,14 @@ use types::{
pub struct ConsensusContext<T: EthSpec> {
/// Slot to act as an identifier/safeguard
slot: Slot,
/// Previous epoch of the `slot` precomputed for optimization purpose.
pub(crate) previous_epoch: Epoch,
/// Current epoch of the `slot` precomputed for optimization purpose.
pub(crate) current_epoch: Epoch,
/// Proposer index of the block at `slot`.
proposer_index: Option<u64>,
/// Block root of the block at `slot`.
current_block_root: Option<Hash256>,
/// Epoch cache of values that are useful for block processing that are static over an epoch.
epoch_cache: Option<EpochCache>,
/// Cache of indexed attestations constructed during block processing.
indexed_attestations:
HashMap<(AttestationData, BitList<T::MaxValidatorsPerCommittee>), IndexedAttestation<T>>,
Expand Down Expand Up @@ -48,11 +49,14 @@ impl From<EpochCacheError> for ContextError {

impl<T: EthSpec> ConsensusContext<T> {
pub fn new(slot: Slot) -> Self {
let current_epoch = slot.epoch(T::slots_per_epoch());
let previous_epoch = current_epoch.saturating_sub(1u64);
Self {
slot,
previous_epoch,
current_epoch,
proposer_index: None,
current_block_root: None,
epoch_cache: None,
indexed_attestations: HashMap::new(),
_phantom: PhantomData,
}
Expand Down Expand Up @@ -145,31 +149,6 @@ impl<T: EthSpec> ConsensusContext<T> {
}
}

pub fn set_epoch_cache(mut self, epoch_cache: EpochCache) -> Self {
self.epoch_cache = Some(epoch_cache);
self
}

pub fn get_base_reward(
&mut self,
state: &BeaconState<T>,
validator_index: usize,
spec: &ChainSpec,
) -> Result<u64, ContextError> {
self.check_slot(state.slot())?;

// Build epoch cache if not already built.
let epoch_cache = if let Some(ref cache) = self.epoch_cache {
Cow::Borrowed(cache)
} else {
let cache = EpochCache::new(state, spec)?;
self.epoch_cache = Some(cache.clone());
Cow::Owned(cache)
};

Ok(epoch_cache.get_base_reward(validator_index)?)
}

pub fn get_indexed_attestation(
&mut self,
state: &BeaconState<T>,
Expand Down
180 changes: 49 additions & 131 deletions consensus/state_processing/src/epoch_cache.rs
Original file line number Diff line number Diff line change
@@ -1,137 +1,55 @@
use crate::common::{
altair::{self, BaseRewardPerIncrement},
base::{self, SqrtTotalActiveBalance},
};
use safe_arith::ArithError;
use std::sync::Arc;
use types::{BeaconState, BeaconStateError, ChainSpec, Epoch, EthSpec, Hash256, Slot};

/// Cache of values which are uniquely determined at the start of an epoch.
///
/// The values are fixed with respect to the last block of the _prior_ epoch, which we refer
/// to as the "decision block". This cache is very similar to the `BeaconProposerCache` in that
/// beacon proposers are determined at exactly the same time as the values in this cache, so
/// the keys for the two caches are identical.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct EpochCache {
inner: Arc<Inner>,
}

#[derive(Debug, PartialEq, Eq, Clone)]
struct Inner {
/// Unique identifier for this cache, which can be used to check its validity before use
/// with any `BeaconState`.
key: EpochCacheKey,
/// Base reward for every validator in this epoch.
base_rewards: Vec<u64>,
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct EpochCacheKey {
pub epoch: Epoch,
pub decision_block_root: Hash256,
}

#[derive(Debug, PartialEq, Clone)]
pub enum EpochCacheError {
IncorrectEpoch { cache: Epoch, state: Epoch },
IncorrectDecisionBlock { cache: Hash256, state: Hash256 },
ValidatorIndexOutOfBounds { validator_index: usize },
InvalidSlot { slot: Slot },
Arith(ArithError),
BeaconState(BeaconStateError),
}

impl From<BeaconStateError> for EpochCacheError {
fn from(e: BeaconStateError) -> Self {
Self::BeaconState(e)
use crate::common::altair::BaseRewardPerIncrement;
use crate::common::base::SqrtTotalActiveBalance;
use crate::common::{altair, base};
use types::epoch_cache::{EpochCache, EpochCacheError, EpochCacheKey};
use types::{BeaconState, ChainSpec, Epoch, EthSpec, Hash256};

pub fn initialize_epoch_cache<E: EthSpec>(
state: &mut BeaconState<E>,
epoch: Epoch,
spec: &ChainSpec,
) -> Result<(), EpochCacheError> {
let epoch_cache: &EpochCache = state.epoch_cache();
let decision_block_root = state
.proposer_shuffling_decision_root(Hash256::zero())
.map_err(EpochCacheError::BeaconState)?;

if epoch_cache
.check_validity::<E>(epoch, decision_block_root)
.is_ok()
{
// `EpochCache` has already been initialized and is valid, no need to initialize.
return Ok(());
}
}

impl From<ArithError> for EpochCacheError {
fn from(e: ArithError) -> Self {
Self::Arith(e)
// Compute base rewards.
let total_active_balance = state.get_total_active_balance_at_epoch(epoch)?;
let sqrt_total_active_balance = SqrtTotalActiveBalance::new(total_active_balance);
let base_reward_per_increment = BaseRewardPerIncrement::new(total_active_balance, spec)?;

let mut base_rewards = Vec::with_capacity(state.validators().len());

for validator in state.validators().iter() {
let effective_balance = validator.effective_balance();

let base_reward = if spec
.altair_fork_epoch
.map_or(false, |altair_epoch| epoch < altair_epoch)
{
base::get_base_reward(effective_balance, sqrt_total_active_balance, spec)?
} else {
altair::get_base_reward(effective_balance, base_reward_per_increment, spec)?
};
base_rewards.push(base_reward);
}
}

impl EpochCache {
pub fn new<E: EthSpec>(
state: &BeaconState<E>,
spec: &ChainSpec,
) -> Result<Self, EpochCacheError> {
let epoch = state.current_epoch();
let decision_block_root = state
.proposer_shuffling_decision_root(Hash256::zero())
.map_err(EpochCacheError::BeaconState)?;
*state.epoch_cache_mut() = EpochCache::new(
EpochCacheKey {
epoch,
decision_block_root,
},
base_rewards,
);

// The cache should never be constructed at slot 0 because it should only be used for
// block processing (which implies slot > 0) or epoch processing (which implies slot >= 32).
/* FIXME(sproul): EF tests like this
if decision_block_root.is_zero() {
return Err(EpochCacheError::InvalidSlot { slot: state.slot() });
}
*/

// Compute base rewards.
let total_active_balance = state.get_total_active_balance()?;
let sqrt_total_active_balance = SqrtTotalActiveBalance::new(total_active_balance);
let base_reward_per_increment = BaseRewardPerIncrement::new(total_active_balance, spec)?;

let mut base_rewards = Vec::with_capacity(state.validators().len());

for validator in state.validators().iter() {
let effective_balance = validator.effective_balance();

let base_reward = if spec
.altair_fork_epoch
.map_or(false, |altair_epoch| epoch < altair_epoch)
{
base::get_base_reward(effective_balance, sqrt_total_active_balance, spec)?
} else {
altair::get_base_reward(effective_balance, base_reward_per_increment, spec)?
};
base_rewards.push(base_reward);
}

Ok(Self {
inner: Arc::new(Inner {
key: EpochCacheKey {
epoch,
decision_block_root,
},
base_rewards,
}),
})
}

pub fn check_validity<E: EthSpec>(
&self,
state: &BeaconState<E>,
) -> Result<(), EpochCacheError> {
if self.inner.key.epoch != state.current_epoch() {
return Err(EpochCacheError::IncorrectEpoch {
cache: self.inner.key.epoch,
state: state.current_epoch(),
});
}
let state_decision_root = state
.proposer_shuffling_decision_root(Hash256::zero())
.map_err(EpochCacheError::BeaconState)?;
if self.inner.key.decision_block_root != state_decision_root {
return Err(EpochCacheError::IncorrectDecisionBlock {
cache: self.inner.key.decision_block_root,
state: state_decision_root,
});
}
Ok(())
}

#[inline]
pub fn get_base_reward(&self, validator_index: usize) -> Result<u64, EpochCacheError> {
self.inner
.base_rewards
.get(validator_index)
.copied()
.ok_or(EpochCacheError::ValidatorIndexOutOfBounds { validator_index })
}
Ok(())
}
2 changes: 1 addition & 1 deletion consensus/state_processing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub mod verify_operation;

pub use block_replayer::{BlockReplayError, BlockReplayer, StateProcessingStrategy};
pub use consensus_context::{ConsensusContext, ContextError};
pub use epoch_cache::{EpochCache, EpochCacheError, EpochCacheKey};
pub use genesis::{
eth2_genesis_time, initialize_beacon_state_from_eth1, is_valid_genesis_state,
process_activations,
Expand All @@ -43,4 +42,5 @@ pub use per_epoch_processing::{
errors::EpochProcessingError, process_epoch as per_epoch_processing,
};
pub use per_slot_processing::{per_slot_processing, Error as SlotProcessingError};
pub use types::{EpochCache, EpochCacheError, EpochCacheKey};
pub use verify_operation::{SigVerifiedOp, VerifyOperation, VerifyOperationAt};
4 changes: 4 additions & 0 deletions consensus/state_processing/src/per_block_processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ mod verify_proposer_slashing;
use crate::common::decrease_balance;
use crate::StateProcessingStrategy;

use crate::epoch_cache::initialize_epoch_cache;
#[cfg(feature = "arbitrary-fuzz")]
use arbitrary::Arbitrary;

Expand Down Expand Up @@ -114,6 +115,9 @@ pub fn per_block_processing<T: EthSpec, Payload: AbstractExecPayload<T>>(
.fork_name(spec)
.map_err(BlockProcessingError::InconsistentStateFork)?;

// Build epoch cache if it hasn't already been built, or if it is no longer valid
initialize_epoch_cache(state, state.current_epoch(), spec)?;

let verify_signatures = match block_signature_strategy {
BlockSignatureStrategy::VerifyBulk => {
// Verify all signatures in the block at once.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ pub mod altair {
})
}

#[allow(clippy::too_many_arguments)]
pub fn process_attestation<T: EthSpec>(
state: &mut BeaconState<T>,
attestation: &Attestation<T>,
Expand Down Expand Up @@ -149,18 +150,22 @@ pub mod altair {
let index = *index as usize;

for (flag_index, &weight) in PARTICIPATION_FLAG_WEIGHTS.iter().enumerate() {
let epoch_participation = state.get_epoch_participation_mut(data.target.epoch)?;
let validator_participation = epoch_participation
.get_mut(index)
.ok_or(BeaconStateError::ParticipationOutOfBounds(index))?;

if participation_flag_indices.contains(&flag_index)
&& !validator_participation.has_flag(flag_index)?
{
validator_participation.add_flag(flag_index)?;
proposer_reward_numerator.safe_add_assign(
ctxt.get_base_reward(state, index, spec)?.safe_mul(weight)?,
)?;
let epoch_participation = state.get_epoch_participation_mut(
data.target.epoch,
ctxt.previous_epoch,
ctxt.current_epoch,
)?;

if participation_flag_indices.contains(&flag_index) {
let validator_participation = epoch_participation
.get_mut(index)
.ok_or(BeaconStateError::ParticipationOutOfBounds(index))?;

if !validator_participation.has_flag(flag_index)? {
validator_participation.add_flag(flag_index)?;
proposer_reward_numerator
.safe_add_assign(state.get_base_reward(index)?.safe_mul(weight)?)?;
}
}
}
}
Expand Down
Loading

0 comments on commit 2df714e

Please sign in to comment.