Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new versions of choose and choose_stable #1268

Merged
merged 29 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bd6b9c9
Added new versions of choose and choose_stable
wainwrightmark Nov 17, 2022
eb5672b
Removed coin_flipper tests which were unnecessary and not building on ci
wainwrightmark Nov 17, 2022
ecb1158
Performance optimizations in coin_flipper
wainwrightmark Nov 18, 2022
6a0d278
Clippy fixes and more documentation
wainwrightmark Nov 18, 2022
b9c0b20
Added a correctness fix for coin_flipper
wainwrightmark Nov 18, 2022
0ce6bfa
Update benches/seq.rs
wainwrightmark Nov 21, 2022
7f34c55
Update benches/seq.rs
wainwrightmark Nov 21, 2022
1fe6c9f
Removed old version of choose and choose stable and updated value sta…
wainwrightmark Nov 21, 2022
0209e41
Merge branch 'master' of https://github.com/wainwrightmark/rand
wainwrightmark Nov 21, 2022
79f6953
Moved sequence choose benchmarks to their own file
wainwrightmark Nov 21, 2022
b5312f4
Reworked coin_flipper
wainwrightmark Dec 5, 2022
2339539
Use criterion for seq_choose benches
wainwrightmark Dec 5, 2022
309959c
Removed an old comment
wainwrightmark Dec 5, 2022
2a2f434
Change how c is estimated in coin_flipper
wainwrightmark Dec 5, 2022
b3fdc3f
Revert "Use criterion for seq_choose benches"
wainwrightmark Dec 5, 2022
b3062e5
Added seq_choose benches for smaller numbers
wainwrightmark Dec 5, 2022
a7a7a90
Removed some unneeded lines from seq_choose
wainwrightmark Dec 5, 2022
4726a60
Improvements in coin_flipper.rs
wainwrightmark Dec 9, 2022
68dc604
Small refactor of coin_flipper
wainwrightmark Dec 9, 2022
8723cda
Tidied comments in coin_flipper
wainwrightmark Dec 9, 2022
a2c4cce
Use criterion for seq_choose benchmarks
wainwrightmark Dec 9, 2022
9a798aa
Merge branch 'rust-random:master' into master
wainwrightmark Dec 9, 2022
8601cd6
Made choose not generate a random number if len=1
wainwrightmark Dec 9, 2022
03a7d8b
Merge branch 'master' of https://github.com/rust-random/rand into rus…
wainwrightmark Dec 13, 2022
f6e7fec
small change to IteratorRandom::choose
wainwrightmark Jan 4, 2023
999104c
Made it easier to change seq_choose benchmarks RNG
wainwrightmark Jan 4, 2023
8816449
Added Pcg64 benchmarks for seq_choose
wainwrightmark Jan 4, 2023
d76ddb7
Added TODO to coin_flipper
wainwrightmark Jan 4, 2023
a9aade6
Changed criterion settings in seq_choose
wainwrightmark Jan 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 1 addition & 71 deletions benches/seq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ extern crate test;

use test::Bencher;

use core::mem::size_of;
use rand::prelude::*;
use rand::seq::*;
use core::mem::size_of;

// We force use of 32-bit RNG since seq code is optimised for use with 32-bit
// generators on all platforms.
Expand Down Expand Up @@ -74,76 +74,6 @@ seq_slice_choose_multiple!(seq_slice_choose_multiple_950_of_1000, 950, 1000);
seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100);
seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100);

/// Benchmarks `IteratorRandom::choose` on a slice iterator over 1000 distinct
/// values, performing `RAND_BENCH_N` selections per measured iteration.
#[bench]
fn seq_iter_choose_from_1000(b: &mut Bencher) {
    let mut rng = SmallRng::from_rng(thread_rng()).unwrap();

    // Fill the buffer with 0, 1, 2, ... so every element is distinguishable.
    let values: &mut [usize] = &mut [1; 1000];
    values
        .iter_mut()
        .enumerate()
        .for_each(|(index, slot)| *slot = index);

    // The sum is returned from the closure so the selections cannot be optimised away.
    b.iter(|| {
        (0..RAND_BENCH_N)
            .map(|_| *values.iter().choose(&mut rng).unwrap())
            .sum::<usize>()
    });
    b.bytes = size_of::<usize>() as u64 * crate::RAND_BENCH_N;
}

/// Iterator adaptor that forwards `next` to the wrapped iterator but does not
/// forward `size_hint`, so callers only ever see the default `(0, None)` hint.
#[derive(Clone)]
struct UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    iter: I,
}

impl<I> Iterator for UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    type Item = I::Item;

    /// Yields the next item of the wrapped iterator unchanged.
    fn next(&mut self) -> Option<Self::Item> {
        let inner = &mut self.iter;
        inner.next()
    }
}

/// Iterator adaptor whose `size_hint` lower bound is capped at `window_size`
/// and whose upper bound is always unknown, even though the wrapped iterator
/// knows its exact length.
#[derive(Clone)]
struct WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    iter: I,
    window_size: usize,
}

impl<I> Iterator for WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    type Item = I::Item;

    /// Yields the next item of the wrapped iterator unchanged.
    fn next(&mut self) -> Option<Self::Item> {
        let inner = &mut self.iter;
        inner.next()
    }

    /// Reports `min(remaining, window_size)` as the lower bound and no upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.iter.len();
        let lower = remaining.min(self.window_size);
        (lower, None)
    }
}

/// Benchmarks a single `choose` call over 1000 elements when the iterator
/// provides no size hint.
#[bench]
fn seq_iter_unhinted_choose_from_1000(b: &mut Bencher) {
    let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
    let data: &[usize] = &[1; 1000];
    b.iter(|| {
        let unhinted = UnhintedIterator { iter: data.iter() };
        unhinted.choose(&mut rng).unwrap()
    })
}

/// Benchmarks a single `choose` call over 1000 elements when the iterator
/// only ever hints at a window of at most 7 upcoming items.
#[bench]
fn seq_iter_window_hinted_choose_from_1000(b: &mut Bencher) {
    let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
    let data: &[usize] = &[1; 1000];
    b.iter(|| {
        let windowed = WindowHintedIterator {
            iter: data.iter(),
            window_size: 7,
        };
        windowed.choose(&mut rng)
    })
}

#[bench]
fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
Expand Down
163 changes: 163 additions & 0 deletions benches/seq_choose.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#![feature(test)]
#![allow(non_snake_case)]
#![feature(custom_inner_attributes)]
// Rustfmt splits macro invocations to shorten lines; in this case longer-lines are more readable
#![rustfmt::skip]

extern crate test;

use test::Bencher;
use rand::prelude::*;

// We force use of 32-bit RNG since seq code is optimised for use with 32-bit
// generators on all platforms.
use rand_chacha::ChaCha20Rng as CryptoRng;
use rand_pcg::Pcg32 as SmallRng;

const RAND_BENCH_N: u64 = 1000;

/// Iterator adaptor that forwards `next` to the wrapped iterator but does not
/// forward `size_hint`, so callers only ever see the default `(0, None)` hint.
#[derive(Clone)]
struct UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    iter: I,
}

impl<I> Iterator for UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    type Item = I::Item;

    /// Yields the next item of the wrapped iterator unchanged.
    fn next(&mut self) -> Option<Self::Item> {
        let inner = &mut self.iter;
        inner.next()
    }
}

/// Iterator adaptor whose `size_hint` lower bound is capped at `window_size`
/// and whose upper bound is always unknown, even though the wrapped iterator
/// knows its exact length.
#[derive(Clone)]
struct WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    iter: I,
    window_size: usize,
}

impl<I> Iterator for WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    type Item = I::Item;

    /// Yields the next item of the wrapped iterator unchanged.
    fn next(&mut self) -> Option<Self::Item> {
        let inner = &mut self.iter;
        inner.next()
    }

    /// Reports `min(remaining, window_size)` as the lower bound and no upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.iter.len();
        let lower = remaining.min(self.window_size);
        (lower, None)
    }
}

/// Generates a `#[bench]` function named `$name` that measures the iterator
/// method `$fn` (e.g. `choose`) on a plain slice iterator, which reports an
/// exact size hint.
///
/// * `$rng`    - RNG type, seeded from `thread_rng()`.
/// * `$fn`     - method invoked on the iterator with `&mut rng`; its result is unwrapped.
/// * `$length` - number of elements in the source array.
///
/// Each measured iteration performs `RAND_BENCH_N` selections.
macro_rules! bench_seq_iter_size_hinted {
    ($name:ident, $rng:ident, $fn:ident, $length:expr) => {
        #[bench]
        fn $name(b: &mut Bencher) {
            let mut rng = $rng::from_rng(thread_rng()).unwrap();
            // Fill the buffer with 0, 1, 2, ... so every element is distinguishable.
            let x: &mut [usize] = &mut [1; $length];
            for (i, r) in x.iter_mut().enumerate() {
                *r = i;
            }
            // The sum is returned from the closure so the selections cannot be optimised away.
            b.iter(|| {
                let mut s = 0;
                for _ in 0..RAND_BENCH_N {
                    s += x.iter().$fn(&mut rng).unwrap();
                }
                s
            });
        }
    };
}

/// Generates a `#[bench]` function named `$name` that measures one call of the
/// iterator method `$fn` on an `UnhintedIterator` (no size hint available).
///
/// * `$rng`    - RNG type, seeded from `thread_rng()`.
/// * `$fn`     - method invoked on the iterator with `&mut rng`; its result is unwrapped.
/// * `$length` - number of elements in the source array.
macro_rules! bench_seq_iter_unhinted {
    ($name:ident, $rng:ident, $fn:ident, $length:expr) => {
        #[bench]
        fn $name(b: &mut Bencher) {
            let mut rng = $rng::from_rng(thread_rng()).unwrap();
            let data: &[usize] = &[1; $length];
            b.iter(|| {
                let unhinted = UnhintedIterator { iter: data.iter() };
                unhinted.$fn(&mut rng).unwrap()
            })
        }
    };
}

/// Generates a `#[bench]` function named `$name` that measures one call of the
/// iterator method `$fn` on a `WindowHintedIterator` with a window of 7.
///
/// * `$rng`    - RNG type, seeded from `thread_rng()`.
/// * `$fn`     - method invoked on the iterator with `&mut rng`; its result is unwrapped.
/// * `$length` - number of elements in the source array.
macro_rules! bench_seq_iter_window_hinted {
    ($name:ident, $rng:ident, $fn:ident, $length:expr) => {
        #[bench]
        fn $name(b: &mut Bencher) {
            let mut rng = $rng::from_rng(thread_rng()).unwrap();
            let data: &[usize] = &[1; $length];
            b.iter(|| {
                let windowed = WindowHintedIterator {
                    iter: data.iter(),
                    window_size: 7,
                };
                windowed.$fn(&mut rng).unwrap()
            })
        }
    };
}

//Size Hinted
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_10000_cryptoRng, CryptoRng, choose, 10000);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_10000_smallRng, SmallRng, choose, 10000);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_1000_cryptoRng, CryptoRng, choose, 1000);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_1000_smallRng, SmallRng, choose, 1000);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_100_cryptoRng, CryptoRng, choose, 100);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_100_smallRng, SmallRng, choose, 100);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_10_smallRng, SmallRng, choose, 10);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_10_cryptoRng, CryptoRng, choose, 10);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_3_smallRng, SmallRng, choose, 3);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_3_cryptoRng, CryptoRng, choose, 3);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_2_smallRng, SmallRng, choose, 2);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_2_cryptoRng, CryptoRng, choose, 2);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_1_smallRng, SmallRng, choose, 1);
bench_seq_iter_size_hinted!(seq_iter_size_hinted_choose_from_1_cryptoRng, CryptoRng, choose, 1);

//Unhinted
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_10000_cryptoRng, CryptoRng, choose, 10000);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_10000_smallRng, SmallRng, choose, 10000);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_1000_cryptoRng, CryptoRng, choose, 1000);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_1000_smallRng, SmallRng, choose, 1000);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_100_cryptoRng, CryptoRng, choose, 100);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_100_smallRng, SmallRng, choose, 100);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_10_smallRng, SmallRng, choose, 10);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_10_cryptoRng, CryptoRng, choose, 10);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_3_smallRng, SmallRng, choose, 3);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_3_cryptoRng, CryptoRng, choose, 3);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_2_smallRng, SmallRng, choose, 2);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_2_cryptoRng, CryptoRng, choose, 2);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_1_smallRng, SmallRng, choose, 1);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_from_1_cryptoRng, CryptoRng, choose, 1);

// Window hinted
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_10000_cryptoRng, CryptoRng, choose, 10000);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_10000_smallRng, SmallRng, choose, 10000);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_1000_cryptoRng, CryptoRng, choose, 1000);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_1000_smallRng, SmallRng, choose, 1000);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_100_cryptoRng, CryptoRng, choose, 100);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_100_smallRng, SmallRng, choose, 100);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_10_smallRng, SmallRng, choose, 10);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_10_cryptoRng, CryptoRng, choose, 10);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_3_smallRng, SmallRng, choose, 3);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_3_cryptoRng, CryptoRng, choose, 3);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_2_smallRng, SmallRng, choose, 2);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_2_cryptoRng, CryptoRng, choose, 2);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_1_smallRng, SmallRng, choose, 1);
bench_seq_iter_window_hinted!(seq_iter_window_hinted_choose_from_1_cryptoRng, CryptoRng, choose, 1);

//Choose Stable
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_10000_smallRng, SmallRng, choose_stable, 10000);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_10000_cryptoRng, CryptoRng, choose_stable, 10000);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_1000_smallRng, SmallRng, choose_stable, 1000);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_1000_cryptoRng, CryptoRng, choose_stable, 1000);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_100_smallRng, SmallRng, choose_stable, 100);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_100_cryptoRng, CryptoRng, choose_stable, 100);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_10_smallRng, SmallRng, choose_stable, 10);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_10_cryptoRng, CryptoRng, choose_stable, 10);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_3_smallRng, SmallRng, choose_stable, 3);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_3_cryptoRng, CryptoRng, choose_stable, 3);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_2_smallRng, SmallRng, choose_stable, 2);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_2_cryptoRng, CryptoRng, choose_stable, 2);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_1_smallRng, SmallRng, choose_stable, 1);
bench_seq_iter_unhinted!(seq_iter_unhinted_choose_stable_from_1_cryptoRng, CryptoRng, choose_stable, 1);

138 changes: 138 additions & 0 deletions src/seq/coin_flipper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
use crate::RngCore;

/// Produces biased boolean outcomes from fair coin flips, drawing randomness
/// from the wrapped RNG 32 bits at a time and consuming it bit by bit.
pub(crate) struct CoinFlipper<R: RngCore> {
    pub rng: R,
    /// Unused random bits, left-aligned: the next bit to consume is the most
    /// significant one. A `0` bit is "heads", a `1` bit is "tails".
    /// Every bit below the `chunk_remaining` valid ones is zero.
    chunk: u32,
    /// Number of random bits still available at the top of `chunk`.
    chunk_remaining: u32,
}

impl<R: RngCore> CoinFlipper<R> {
    /// Creates a flipper with an empty bit buffer; the first flip pulls a
    /// fresh 32-bit chunk from `rng`.
    pub fn new(rng: R) -> Self {
        Self {
            rng,
            chunk: 0,
            chunk_remaining: 0,
        }
    }

    /// Returns true with a probability of 1 / d.
    /// Uses an expected two bits of randomness.
    ///
    /// `d` must be non-zero (checked with a debug assertion).
    #[inline]
    pub fn gen_ratio_one_over(&mut self, d: usize) -> bool {
        debug_assert_ne!(d, 0, "cannot return true with probability 1 / 0");
        // Same logic as `gen_ratio`, specialised for a starting numerator of one.
        // With n = 1, floor(log2(d)) is exactly the largest c with 2^c <= d
        // (capped at the 32 flips a single call can check).
        let c = (usize::BITS - 1 - d.leading_zeros()).min(32);

        if self.flip_until_tails(c) {
            // c heads in a row: the numerator has been doubled c times.
            self.gen_ratio(1 << c, d)
        } else {
            false
        }
    }

    /// Returns true with a probability of n / d.
    /// Uses an expected two bits of randomness.
    #[inline]
    fn gen_ratio(&mut self, mut n: usize, d: usize) -> bool {
        // Explanation:
        // We want to return true with probability n / d.
        // If n >= d we can just return true.
        // Otherwise flip a coin:
        //   If 2n < d
        //     tails: return false
        //     heads: double n and start again
        //     Fair because (0.5 * 0) + (0.5 * 2n / d) = n / d
        //   If 2n >= d
        //     tails: set n to 2n - d and start again
        //     heads: return true
        //     Fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d
        //   If 2n = d and the coin comes up tails, n becomes 0, which is
        //   equivalent to returning false.
        //
        // As an optimisation several coins are flipped at once. We only learn
        // "all heads" or "at least one tail", so batching c > 1 flips is only
        // valid while n * 2^c < d; otherwise we flip one coin at a time.

        while n < d {
            // Estimate c from the bit lengths of n and d. This is either the
            // highest valid c or one less. Whenever c > 1, n * 2^c < d holds.
            let c = n
                .leading_zeros()
                .saturating_sub(d.leading_zeros() + 1)
                .clamp(1, 32);

            if self.flip_until_tails(c) {
                // All heads: n becomes n * 2^c.
                // For c > 1 this cannot overflow (it is below d). For c == 1
                // it may exceed `usize::MAX`; saturating keeps it >= d so the
                // loop exits and we return true, as required when 2n >= d.
                n = n.saturating_mul(1 << c);
            } else if c == 1 {
                // One tail: move to 2n - d, or return false when 2n <= d.
                // Computed as n - (d - n) so that 2n is never materialised
                // and cannot overflow.
                let shortfall = d - n;
                if n <= shortfall {
                    return false;
                }
                n -= shortfall;
            } else {
                // At least one tail among c > 1 flips. Because n * 2^c < d,
                // the first tail was hit while 2n < d, so the answer is false.
                return false;
            }
        }
        true
    }

    /// If the next `c` bits of randomness are all zeroes, consume them and return true.
    /// Otherwise return false and consume the number of zeroes plus one.
    /// Generates new bits of randomness when necessary (in 32-bit chunks).
    /// Has a one in 2 to the `c` chance of returning true.
    /// `c` must be less than or equal to 32.
    fn flip_until_tails(&mut self, mut c: u32) -> bool {
        debug_assert!(c <= 32); // If `c` > 32 this would always return false
        // Zeros on the left of the chunk represent heads. It has to be this
        // way round because zeros are filled in when left shifting.
        loop {
            let zeros = self.chunk.leading_zeros();

            if zeros < c {
                // Found a 1 (tails) within the first `c` bits. A 1 bit is
                // always a genuine random bit, so there is no need to check
                // `chunk_remaining` here.

                // Consume the zeros and the 1. `zeros + 1` may be 32, which a
                // single shift cannot express, so shift in two steps; this
                // guarantees the consumed 1 bit is really discarded.
                self.chunk = (self.chunk << zeros) << 1;
                self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1);
                return false;
            }

            // At least `c` leading zeros: either they are all random bits, or
            // some of them are the padding behind the random bits.
            if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) {
                // All `c` zeros were random: consume them and report heads.
                // When `c` is 32 the chunk is necessarily zero, so the masked
                // shift of `wrapping_shl` still yields the right value.
                self.chunk_remaining = new_remaining;
                self.chunk = self.chunk.wrapping_shl(c);
                return true;
            }

            // Only `chunk_remaining` of the zeros were random; count those as
            // heads and continue with a fresh chunk.
            c -= self.chunk_remaining;
            self.chunk = self.rng.next_u32();
            self.chunk_remaining = 32;
        }
    }
}
Loading