@@ -84,6 +84,7 @@ use core::fmt;
84
84
#[ derive( Debug , Clone ) ]
85
85
pub struct WeightedIndex < X : SampleUniform + PartialOrd > {
86
86
cumulative_weights : Vec < X > ,
87
+ total_weight : X ,
87
88
weight_distribution : X :: Sampler ,
88
89
}
89
90
@@ -125,9 +126,98 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
125
126
if total_weight == zero {
126
127
return Err ( WeightedError :: AllWeightsZero ) ;
127
128
}
128
- let distr = X :: Sampler :: new ( zero, total_weight) ;
129
+ let distr = X :: Sampler :: new ( zero, total_weight. clone ( ) ) ;
129
130
130
- Ok ( WeightedIndex { cumulative_weights : weights, weight_distribution : distr } )
131
+ Ok ( WeightedIndex { cumulative_weights : weights, total_weight, weight_distribution : distr } )
132
+ }
133
+
134
+ /// Update a subset of weights, without changing the number of weights.
135
+ ///
136
+ /// `new_weights` must be sorted by the index.
137
+ ///
138
+ /// Using this method instead of `new` might be more efficient if only a small number of
139
+ /// weights is modified. No allocations are performed, unless the weight type `X` uses
140
+ /// allocation internally.
141
+ ///
142
+ /// In case of error, `self` is not modified.
143
+ pub fn update_weights ( & mut self , new_weights : & [ ( usize , & X ) ] ) -> Result < ( ) , WeightedError >
144
+ where X : for < ' a > :: core:: ops:: AddAssign < & ' a X > +
145
+ for < ' a > :: core:: ops:: SubAssign < & ' a X > +
146
+ Clone +
147
+ Default {
148
+ if new_weights. is_empty ( ) {
149
+ return Ok ( ( ) ) ;
150
+ }
151
+
152
+ let zero = <X as Default >:: default ( ) ;
153
+
154
+ let mut total_weight = self . total_weight . clone ( ) ;
155
+
156
+ // Check for errors first, so we don't modify `self` in case something
157
+ // goes wrong.
158
+ let mut prev_i = None ;
159
+ for & ( i, w) in new_weights {
160
+ if let Some ( old_i) = prev_i {
161
+ if old_i >= i {
162
+ return Err ( WeightedError :: InvalidWeight ) ;
163
+ }
164
+ }
165
+ if * w < zero {
166
+ return Err ( WeightedError :: InvalidWeight ) ;
167
+ }
168
+ if i >= self . cumulative_weights . len ( ) + 1 {
169
+ return Err ( WeightedError :: TooMany ) ;
170
+ }
171
+
172
+ let mut old_w = if i < self . cumulative_weights . len ( ) {
173
+ self . cumulative_weights [ i] . clone ( )
174
+ } else {
175
+ self . total_weight . clone ( )
176
+ } ;
177
+ if i > 0 {
178
+ old_w -= & self . cumulative_weights [ i - 1 ] ;
179
+ }
180
+
181
+ total_weight -= & old_w;
182
+ total_weight += w;
183
+ prev_i = Some ( i) ;
184
+ }
185
+ if total_weight == zero {
186
+ return Err ( WeightedError :: AllWeightsZero ) ;
187
+ }
188
+
189
+ // Update the weights. Because we checked all the preconditions in the
190
+ // previous loop, this should never panic.
191
+ let mut iter = new_weights. iter ( ) ;
192
+
193
+ let mut prev_weight = zero. clone ( ) ;
194
+ let mut next_new_weight = iter. next ( ) ;
195
+ let & ( first_new_index, _) = next_new_weight. unwrap ( ) ;
196
+ let mut cumulative_weight = if first_new_index > 0 {
197
+ self . cumulative_weights [ first_new_index - 1 ] . clone ( )
198
+ } else {
199
+ zero. clone ( )
200
+ } ;
201
+ for i in first_new_index..self . cumulative_weights . len ( ) {
202
+ match next_new_weight {
203
+ Some ( & ( j, w) ) if i == j => {
204
+ cumulative_weight += w;
205
+ next_new_weight = iter. next ( ) ;
206
+ } ,
207
+ _ => {
208
+ let mut tmp = self . cumulative_weights [ i] . clone ( ) ;
209
+ tmp -= & prev_weight; // We know this is positive.
210
+ cumulative_weight += & tmp;
211
+ }
212
+ }
213
+ prev_weight = cumulative_weight. clone ( ) ;
214
+ core:: mem:: swap ( & mut prev_weight, & mut self . cumulative_weights [ i] ) ;
215
+ }
216
+
217
+ self . total_weight = total_weight;
218
+ self . weight_distribution = X :: Sampler :: new ( zero, self . total_weight . clone ( ) ) ;
219
+
220
+ Ok ( ( ) )
131
221
}
132
222
}
133
223
@@ -201,6 +291,31 @@ mod test {
201
291
assert_eq ! ( WeightedIndex :: new( & [ -10 , 20 , 1 , 30 ] ) . unwrap_err( ) , WeightedError :: InvalidWeight ) ;
202
292
assert_eq ! ( WeightedIndex :: new( & [ -10 ] ) . unwrap_err( ) , WeightedError :: InvalidWeight ) ;
203
293
}
294
+
295
+ #[ test]
296
+ fn test_update_weights ( ) {
297
+ let data = [
298
+ ( & [ 10u32 , 2 , 3 , 4 ] [ ..] ,
299
+ & [ ( 1 , & 100 ) , ( 2 , & 4 ) ] [ ..] , // positive change
300
+ & [ 10 , 100 , 4 , 4 ] [ ..] ) ,
301
+ ( & [ 1u32 , 2 , 3 , 0 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ] [ ..] ,
302
+ & [ ( 2 , & 1 ) , ( 5 , & 1 ) , ( 13 , & 100 ) ] [ ..] , // negative change and last element
303
+ & [ 1u32 , 2 , 1 , 0 , 5 , 1 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 100 ] [ ..] ) ,
304
+ ] ;
305
+
306
+ for ( weights, update, expected_weights) in data. into_iter ( ) {
307
+ let total_weight = weights. iter ( ) . sum :: < u32 > ( ) ;
308
+ let mut distr = WeightedIndex :: new ( weights. to_vec ( ) ) . unwrap ( ) ;
309
+ assert_eq ! ( distr. total_weight, total_weight) ;
310
+
311
+ distr. update_weights ( update) . unwrap ( ) ;
312
+ let expected_total_weight = expected_weights. iter ( ) . sum :: < u32 > ( ) ;
313
+ let expected_distr = WeightedIndex :: new ( expected_weights. to_vec ( ) ) . unwrap ( ) ;
314
+ assert_eq ! ( distr. total_weight, expected_total_weight) ;
315
+ assert_eq ! ( distr. total_weight, expected_distr. total_weight) ;
316
+ assert_eq ! ( distr. cumulative_weights, expected_distr. cumulative_weights) ;
317
+ }
318
+ }
204
319
}
205
320
206
321
/// Error type returned from `WeightedIndex::new`.
0 commit comments