Skip to content

Commit 4bde8a0

Browse files
Performance improvements for shuffle and partial_shuffle (#1272)
* Made shuffle and partial_shuffle faster * Use criterion benchmarks for shuffle * Added a note about RNG word size * Tidied comments * Added a debug_assert * Added a comment re possible further optimization * Added and updated copyright notices * Revert cfg mistake * Reverted change to mod.rs * Removed ChaCha20 benches from shuffle * moved debug_assert out of a const fn
1 parent 1e96eb4 commit 4bde8a0

File tree

6 files changed

+205
-19
lines changed

6 files changed

+205
-19
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,9 @@ criterion = { version = "0.4" }
7979
[[bench]]
8080
name = "seq_choose"
8181
path = "benches/seq_choose.rs"
82+
harness = false
83+
84+
[[bench]]
85+
name = "shuffle"
86+
path = "benches/shuffle.rs"
8287
harness = false

benches/seq_choose.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018-2022 Developers of the Rand project.
1+
// Copyright 2018-2023 Developers of the Rand project.
22
//
33
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
44
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license

benches/shuffle.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright 2018-2023 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
9+
use rand::prelude::*;
10+
use rand::SeedableRng;
11+
12+
criterion_group!(
13+
name = benches;
14+
config = Criterion::default();
15+
targets = bench
16+
);
17+
criterion_main!(benches);
18+
19+
pub fn bench(c: &mut Criterion) {
20+
bench_rng::<rand_chacha::ChaCha12Rng>(c, "ChaCha12");
21+
bench_rng::<rand_pcg::Pcg32>(c, "Pcg32");
22+
bench_rng::<rand_pcg::Pcg64>(c, "Pcg64");
23+
}
24+
25+
fn bench_rng<Rng: RngCore + SeedableRng>(c: &mut Criterion, rng_name: &'static str) {
26+
for length in [1, 2, 3, 10, 100, 1000, 10000].map(|x| black_box(x)) {
27+
c.bench_function(format!("shuffle_{length}_{rng_name}").as_str(), |b| {
28+
let mut rng = Rng::seed_from_u64(123);
29+
let mut vec: Vec<usize> = (0..length).collect();
30+
b.iter(|| {
31+
vec.shuffle(&mut rng);
32+
vec[0]
33+
})
34+
});
35+
36+
if length >= 10 {
37+
c.bench_function(
38+
format!("partial_shuffle_{length}_{rng_name}").as_str(),
39+
|b| {
40+
let mut rng = Rng::seed_from_u64(123);
41+
let mut vec: Vec<usize> = (0..length).collect();
42+
b.iter(|| {
43+
vec.partial_shuffle(&mut rng, length / 2);
44+
vec[0]
45+
})
46+
},
47+
);
48+
}
49+
}
50+
}

src/seq/coin_flipper.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
// Copyright 2018-2023 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
19
use crate::RngCore;
210

311
pub(crate) struct CoinFlipper<R: RngCore> {

src/seq/increasing_uniform.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright 2018-2023 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use crate::{Rng, RngCore};
10+
11+
/// Similar to a Uniform distribution,
12+
/// but after returning a number in the range [0,n], n is increased by 1.
13+
pub(crate) struct IncreasingUniform<R: RngCore> {
14+
pub rng: R,
15+
n: u32,
16+
// Chunk is a random number in [0, (n + 1) * (n + 2) *..* (n + chunk_remaining) )
17+
chunk: u32,
18+
chunk_remaining: u8,
19+
}
20+
21+
impl<R: RngCore> IncreasingUniform<R> {
22+
/// Create a dice roller.
23+
/// The next item returned will be a random number in the range [0,n]
24+
pub fn new(rng: R, n: u32) -> Self {
25+
// If n = 0, the first number returned will always be 0
26+
// so we don't need to generate a random number
27+
let chunk_remaining = if n == 0 { 1 } else { 0 };
28+
Self {
29+
rng,
30+
n,
31+
chunk: 0,
32+
chunk_remaining,
33+
}
34+
}
35+
36+
/// Returns a number in [0,n] and increments n by 1.
37+
/// Generates new random bits as needed
38+
/// Panics if `n >= u32::MAX`
39+
#[inline]
40+
pub fn next_index(&mut self) -> usize {
41+
let next_n = self.n + 1;
42+
43+
// There's room for further optimisation here:
44+
// gen_range uses rejection sampling (or other method; see #1196) to avoid bias.
45+
// When the initial sample is biased for range 0..bound
46+
// it may still be viable to use for a smaller bound
47+
// (especially if small biases are considered acceptable).
48+
49+
let next_chunk_remaining = self.chunk_remaining.checked_sub(1).unwrap_or_else(|| {
50+
// If the chunk is empty, generate a new chunk
51+
let (bound, remaining) = calculate_bound_u32(next_n);
52+
// bound = (n + 1) * (n + 2) *..* (n + remaining)
53+
self.chunk = self.rng.gen_range(0..bound);
54+
// Chunk is a random number in
55+
// [0, (n + 1) * (n + 2) *..* (n + remaining) )
56+
57+
remaining - 1
58+
});
59+
60+
let result = if next_chunk_remaining == 0 {
61+
// `chunk` is a random number in the range [0..n+1)
62+
// Because `chunk_remaining` is about to be set to zero
63+
// we do not need to clear the chunk here
64+
self.chunk as usize
65+
} else {
66+
// `chunk` is a random number in a range that is a multiple of n+1
67+
// so r will be a random number in [0..n+1)
68+
let r = self.chunk % next_n;
69+
self.chunk /= next_n;
70+
r as usize
71+
};
72+
73+
self.chunk_remaining = next_chunk_remaining;
74+
self.n = next_n;
75+
result
76+
}
77+
}
78+
79+
#[inline]
80+
/// Calculates `bound`, `count` such that bound (m)*(m+1)*..*(m + remaining - 1)
81+
fn calculate_bound_u32(m: u32) -> (u32, u8) {
82+
debug_assert!(m > 0);
83+
#[inline]
84+
const fn inner(m: u32) -> (u32, u8) {
85+
let mut product = m;
86+
let mut current = m + 1;
87+
88+
loop {
89+
if let Some(p) = u32::checked_mul(product, current) {
90+
product = p;
91+
current += 1;
92+
} else {
93+
// Count has a maximum value of 13 for when min is 1 or 2
94+
let count = (current - m) as u8;
95+
return (product, count);
96+
}
97+
}
98+
}
99+
100+
const RESULT2: (u32, u8) = inner(2);
101+
if m == 2 {
102+
// Making this value a constant instead of recalculating it
103+
// gives a significant (~50%) performance boost for small shuffles
104+
return RESULT2;
105+
}
106+
107+
inner(m)
108+
}

src/seq/mod.rs

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018 Developers of the Rand project.
1+
// Copyright 2018-2023 Developers of the Rand project.
22
//
33
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
44
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -29,6 +29,8 @@ mod coin_flipper;
2929
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
3030
pub mod index;
3131

32+
mod increasing_uniform;
33+
3234
#[cfg(feature = "alloc")]
3335
use core::ops::Index;
3436

@@ -42,6 +44,7 @@ use crate::distributions::WeightedError;
4244
use crate::Rng;
4345

4446
use self::coin_flipper::CoinFlipper;
47+
use self::increasing_uniform::IncreasingUniform;
4548

4649
/// Extension trait on slices, providing random mutation and sampling methods.
4750
///
@@ -620,10 +623,11 @@ impl<T> SliceRandom for [T] {
620623
where
621624
R: Rng + ?Sized,
622625
{
623-
for i in (1..self.len()).rev() {
624-
// invariant: elements with index > i have been locked in place.
625-
self.swap(i, gen_index(rng, i + 1));
626+
if self.len() <= 1 {
627+
// There is no need to shuffle an empty or single element slice
628+
return;
626629
}
630+
self.partial_shuffle(rng, self.len());
627631
}
628632

629633
fn partial_shuffle<R>(
@@ -632,19 +636,30 @@ impl<T> SliceRandom for [T] {
632636
where
633637
R: Rng + ?Sized,
634638
{
635-
// This applies Durstenfeld's algorithm for the
636-
// [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
637-
// for an unbiased permutation, but exits early after choosing `amount`
638-
// elements.
639-
640-
let len = self.len();
641-
let end = if amount >= len { 0 } else { len - amount };
639+
let m = self.len().saturating_sub(amount);
642640

643-
for i in (end..len).rev() {
644-
// invariant: elements with index > i have been locked in place.
645-
self.swap(i, gen_index(rng, i + 1));
641+
// The algorithm below is based on Durstenfeld's algorithm for the
642+
// [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
643+
// for an unbiased permutation.
644+
// It ensures that the last `amount` elements of the slice
645+
// are randomly selected from the whole slice.
646+
647+
//`IncreasingUniform::next_index()` is faster than `gen_index`
648+
//but only works for 32 bit integers
649+
//So we must use the slow method if the slice is longer than that.
650+
if self.len() < (u32::MAX as usize) {
651+
let mut chooser = IncreasingUniform::new(rng, m as u32);
652+
for i in m..self.len() {
653+
let index = chooser.next_index();
654+
self.swap(i, index);
655+
}
656+
} else {
657+
for i in m..self.len() {
658+
let index = gen_index(rng, i + 1);
659+
self.swap(i, index);
660+
}
646661
}
647-
let r = self.split_at_mut(end);
662+
let r = self.split_at_mut(m);
648663
(r.1, r.0)
649664
}
650665
}
@@ -765,11 +780,11 @@ mod test {
765780

766781
let mut r = crate::test::rng(414);
767782
nums.shuffle(&mut r);
768-
assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]);
783+
assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]);
769784
nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
770785
let res = nums.partial_shuffle(&mut r, 6);
771-
assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]);
772-
assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]);
786+
assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]);
787+
assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]);
773788
}
774789

775790
#[derive(Clone)]

0 commit comments

Comments
 (0)