Skip to content

Commit

Permalink
fix: functions operating on party_weights are safe for sum of weights…
Browse files Browse the repository at this point in the history
… > u64::MAX
  • Loading branch information
NikitaMasych committed Oct 17, 2024
1 parent 9e4a65f commit 51e7219
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,19 @@ impl DefaultLeaderElector {
/// - `range`: The upper limit for the random value generation, typically the sum of party weights.
///
/// # Returns
/// A `u64` value within the specified range.
fn hash_to_range(seed: u64, range: u64) -> u64 {
/// A `u128` value within the specified range.
fn hash_to_range(seed: u64, range: u128) -> u128 {
// Determine the number of bits required to represent the range
let mut k = 64;
while 1u64 << (k - 1) > range {
let mut k = 128;
while 1u128 << (k - 1) > range {
k -= 1;
}

// Use a seeded random generator to produce a value within the desired range
let rng = Random::from_seed(Seed::unsafe_new(seed));
loop {
let mut raw_res: u64 = rng.gen();
raw_res >>= 64 - k;
let mut raw_res: u128 = rng.gen::<u128>(); // Generate a u128 random value
raw_res >>= 128 - k;

if raw_res < range {
return raw_res;
Expand Down Expand Up @@ -124,7 +124,7 @@ impl<V: Value, VS: ValueSelector<V>> LeaderElector<V, VS> for DefaultLeaderElect
fn elect_leader(&self, party: &Party<V, VS>) -> Result<u64, Box<dyn std::error::Error>> {
let seed = DefaultLeaderElector::compute_seed(party);

let total_weight: u64 = party.cfg.party_weights.iter().sum();
let total_weight: u128 = party.cfg.party_weights.iter().map(|&x| x as u128).sum();
if total_weight == 0 {
return Err(DefaultLeaderElectorError::ZeroWeightSum.into());
}
Expand All @@ -133,11 +133,11 @@ impl<V: Value, VS: ValueSelector<V>> LeaderElector<V, VS> for DefaultLeaderElect
let random_value = DefaultLeaderElector::hash_to_range(seed, total_weight);

// Use binary search to find the corresponding participant based on the cumulative weight
let mut cumulative_weights = vec![0; party.cfg.party_weights.len()];
cumulative_weights[0] = party.cfg.party_weights[0];
let mut cumulative_weights = vec![0u128; party.cfg.party_weights.len()];
cumulative_weights[0] = party.cfg.party_weights[0] as u128;

for i in 1..party.cfg.party_weights.len() {
cumulative_weights[i] = cumulative_weights[i - 1] + party.cfg.party_weights[i];
cumulative_weights[i] = cumulative_weights[i - 1] + party.cfg.party_weights[i] as u128;
}

match cumulative_weights.binary_search_by(|&weight| {
Expand Down
15 changes: 15 additions & 0 deletions tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,18 @@ async fn test_ballot_many_parties() {

analyze_ballot(results);
}

#[tokio::test]
async fn test_ballot_max_weight() {
let weights = vec![u64::MAX, 1];
let threshold = BPConConfig::compute_bft_threshold(weights.clone());
let cfg = BPConConfig::with_default_timeouts(weights, threshold);

let (parties, receivers, senders) = create_parties(cfg);
let ballot_tasks = launch_parties(parties);
let p2p_task = propagate_p2p(receivers, senders);
let results = await_results(ballot_tasks).await;
p2p_task.abort();

analyze_ballot(results);
}

0 comments on commit 51e7219

Please sign in to comment.