Skip to content

Commit 8616945

Browse files
authored
Merge pull request #866 from vks/update-weights
WeightedIndex: Make it possible to update a subset of weights
2 parents 29056a0 + c9428a0 commit 8616945

File tree

2 files changed

+153
-2
lines changed

2 files changed

+153
-2
lines changed

benches/weighted.rs

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright 2019 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
#![feature(test)]
10+
11+
extern crate test;
12+
13+
use test::Bencher;
14+
use rand::Rng;
15+
use rand::distributions::WeightedIndex;
16+
17+
#[bench]
18+
fn weighted_index_creation(b: &mut Bencher) {
19+
let mut rng = rand::thread_rng();
20+
let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7];
21+
b.iter(|| {
22+
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
23+
rng.sample(distr)
24+
})
25+
}
26+
27+
#[bench]
28+
fn weighted_index_modification(b: &mut Bencher) {
29+
let mut rng = rand::thread_rng();
30+
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
31+
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
32+
b.iter(|| {
33+
distr.update_weights(&[(2, &4), (5, &1)]).unwrap();
34+
rng.sample(&distr)
35+
})
36+
}

src/distributions/weighted/mod.rs

+117-2
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ use core::fmt;
8484
#[derive(Debug, Clone)]
8585
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
8686
cumulative_weights: Vec<X>,
87+
total_weight: X,
8788
weight_distribution: X::Sampler,
8889
}
8990

@@ -125,9 +126,98 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
125126
if total_weight == zero {
126127
return Err(WeightedError::AllWeightsZero);
127128
}
128-
let distr = X::Sampler::new(zero, total_weight);
129+
let distr = X::Sampler::new(zero, total_weight.clone());
129130

130-
Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
131+
Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr })
132+
}
133+
134+
/// Update a subset of weights, without changing the number of weights.
135+
///
136+
/// `new_weights` must be sorted by the index.
137+
///
138+
/// Using this method instead of `new` might be more efficient if only a small number of
139+
/// weights is modified. No allocations are performed, unless the weight type `X` uses
140+
/// allocation internally.
141+
///
142+
/// In case of error, `self` is not modified.
143+
pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
144+
where X: for<'a> ::core::ops::AddAssign<&'a X> +
145+
for<'a> ::core::ops::SubAssign<&'a X> +
146+
Clone +
147+
Default {
148+
if new_weights.is_empty() {
149+
return Ok(());
150+
}
151+
152+
let zero = <X as Default>::default();
153+
154+
let mut total_weight = self.total_weight.clone();
155+
156+
// Check for errors first, so we don't modify `self` in case something
157+
// goes wrong.
158+
let mut prev_i = None;
159+
for &(i, w) in new_weights {
160+
if let Some(old_i) = prev_i {
161+
if old_i >= i {
162+
return Err(WeightedError::InvalidWeight);
163+
}
164+
}
165+
if *w < zero {
166+
return Err(WeightedError::InvalidWeight);
167+
}
168+
if i >= self.cumulative_weights.len() + 1 {
169+
return Err(WeightedError::TooMany);
170+
}
171+
172+
let mut old_w = if i < self.cumulative_weights.len() {
173+
self.cumulative_weights[i].clone()
174+
} else {
175+
self.total_weight.clone()
176+
};
177+
if i > 0 {
178+
old_w -= &self.cumulative_weights[i - 1];
179+
}
180+
181+
total_weight -= &old_w;
182+
total_weight += w;
183+
prev_i = Some(i);
184+
}
185+
if total_weight == zero {
186+
return Err(WeightedError::AllWeightsZero);
187+
}
188+
189+
// Update the weights. Because we checked all the preconditions in the
190+
// previous loop, this should never panic.
191+
let mut iter = new_weights.iter();
192+
193+
let mut prev_weight = zero.clone();
194+
let mut next_new_weight = iter.next();
195+
let &(first_new_index, _) = next_new_weight.unwrap();
196+
let mut cumulative_weight = if first_new_index > 0 {
197+
self.cumulative_weights[first_new_index - 1].clone()
198+
} else {
199+
zero.clone()
200+
};
201+
for i in first_new_index..self.cumulative_weights.len() {
202+
match next_new_weight {
203+
Some(&(j, w)) if i == j => {
204+
cumulative_weight += w;
205+
next_new_weight = iter.next();
206+
},
207+
_ => {
208+
let mut tmp = self.cumulative_weights[i].clone();
209+
tmp -= &prev_weight; // We know this is positive.
210+
cumulative_weight += &tmp;
211+
}
212+
}
213+
prev_weight = cumulative_weight.clone();
214+
core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
215+
}
216+
217+
self.total_weight = total_weight;
218+
self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());
219+
220+
Ok(())
131221
}
132222
}
133223

@@ -201,6 +291,31 @@ mod test {
201291
assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight);
202292
assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight);
203293
}
294+
295+
#[test]
296+
fn test_update_weights() {
297+
let data = [
298+
(&[10u32, 2, 3, 4][..],
299+
&[(1, &100), (2, &4)][..], // positive change
300+
&[10, 100, 4, 4][..]),
301+
(&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
302+
&[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
303+
&[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]),
304+
];
305+
306+
for (weights, update, expected_weights) in data.into_iter() {
307+
let total_weight = weights.iter().sum::<u32>();
308+
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
309+
assert_eq!(distr.total_weight, total_weight);
310+
311+
distr.update_weights(update).unwrap();
312+
let expected_total_weight = expected_weights.iter().sum::<u32>();
313+
let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
314+
assert_eq!(distr.total_weight, expected_total_weight);
315+
assert_eq!(distr.total_weight, expected_distr.total_weight);
316+
assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
317+
}
318+
}
204319
}
205320

206321
/// Error type returned from `WeightedIndex::new`.

0 commit comments

Comments
 (0)