15
15
16
16
use alloc:: sync:: Arc ;
17
17
use core:: mem;
18
- use crate :: sync:: { Condvar , Mutex } ;
18
+ use crate :: sync:: { Condvar , Mutex , MutexGuard } ;
19
19
20
20
use crate :: prelude:: * ;
21
21
@@ -33,6 +33,20 @@ pub(crate) struct Notifier {
33
33
condvar : Condvar ,
34
34
}
35
35
36
+ macro_rules! check_woken {
37
+ ( $guard: expr, $retval: expr) => { {
38
+ if $guard. 0 {
39
+ $guard. 0 = false ;
40
+ if $guard. 1 . as_ref( ) . map( |l| l. lock( ) . unwrap( ) . complete) . unwrap_or( false ) {
41
+ // If we're about to return as woken, and the future state is marked complete, wipe
42
+ // the future state and let the next future wait until we get a new notify.
43
+ $guard. 1 . take( ) ;
44
+ }
45
+ return $retval;
46
+ }
47
+ } }
48
+ }
49
+
36
50
impl Notifier {
37
51
pub ( crate ) fn new ( ) -> Self {
38
52
Self {
@@ -41,45 +55,47 @@ impl Notifier {
41
55
}
42
56
}
43
57
58
+ fn propagate_future_state_to_notify_flag ( & self ) -> MutexGuard < ( bool , Option < Arc < Mutex < FutureState > > > ) > {
59
+ let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
60
+ if let Some ( existing_state) = & lock. 1 {
61
+ if existing_state. lock ( ) . unwrap ( ) . callbacks_made {
62
+ // If the existing `FutureState` has completed and actually made callbacks,
63
+ // consider the notification flag to have been cleared and reset the future state.
64
+ lock. 1 . take ( ) ;
65
+ lock. 0 = false ;
66
+ }
67
+ }
68
+ lock
69
+ }
70
+
44
71
pub ( crate ) fn wait ( & self ) {
45
72
loop {
46
- let mut guard = self . notify_pending . lock ( ) . unwrap ( ) ;
47
- if guard. 0 {
48
- guard. 0 = false ;
49
- return ;
50
- }
73
+ let mut guard = self . propagate_future_state_to_notify_flag ( ) ;
74
+ check_woken ! ( guard, ( ) ) ;
51
75
guard = self . condvar . wait ( guard) . unwrap ( ) ;
52
- let result = guard. 0 ;
53
- if result {
54
- guard. 0 = false ;
55
- return
56
- }
76
+ check_woken ! ( guard, ( ) ) ;
57
77
}
58
78
}
59
79
60
80
#[ cfg( any( test, feature = "std" ) ) ]
61
81
pub ( crate ) fn wait_timeout ( & self , max_wait : Duration ) -> bool {
62
82
let current_time = Instant :: now ( ) ;
63
83
loop {
64
- let mut guard = self . notify_pending . lock ( ) . unwrap ( ) ;
65
- if guard. 0 {
66
- guard. 0 = false ;
67
- return true ;
68
- }
84
+ let mut guard = self . propagate_future_state_to_notify_flag ( ) ;
85
+ check_woken ! ( guard, true ) ;
69
86
guard = self . condvar . wait_timeout ( guard, max_wait) . unwrap ( ) . 0 ;
87
+ check_woken ! ( guard, true ) ;
70
88
// Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the
71
89
// desired wait time has actually passed, and if not then restart the loop with a reduced wait
72
90
// time. Note that this logic can be highly simplified through the use of
73
91
// `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to
74
92
// 1.42.0.
75
93
let elapsed = current_time. elapsed ( ) ;
76
- let result = guard. 0 ;
77
- if result || elapsed >= max_wait {
78
- guard. 0 = false ;
79
- return result;
94
+ if elapsed >= max_wait {
95
+ return false ;
80
96
}
81
97
match max_wait. checked_sub ( elapsed) {
82
- None => return result ,
98
+ None => return false ,
83
99
Some ( _) => continue
84
100
}
85
101
}
@@ -88,17 +104,8 @@ impl Notifier {
88
104
/// Wake waiters, tracking that wake needs to occur even if there are currently no waiters.
89
105
pub ( crate ) fn notify ( & self ) {
90
106
let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
91
- let mut future_probably_generated_calls = false ;
92
- if let Some ( future_state) = lock. 1 . take ( ) {
93
- future_probably_generated_calls |= future_state. lock ( ) . unwrap ( ) . complete ( ) ;
94
- future_probably_generated_calls |= Arc :: strong_count ( & future_state) > 1 ;
95
- }
96
- if future_probably_generated_calls {
97
- // If a future made some callbacks or has not yet been drop'd (i.e. the state has more
98
- // than the one reference we hold), assume the user was notified and skip setting the
99
- // notification-required flag. This will not cause the `wait` functions above to return
100
- // and avoid any future `Future`s starting in a completed state.
101
- return ;
107
+ if let Some ( future_state) = & lock. 1 {
108
+ future_state. lock ( ) . unwrap ( ) . complete ( ) ;
102
109
}
103
110
lock. 0 = true ;
104
111
mem:: drop ( lock) ;
@@ -107,20 +114,14 @@ impl Notifier {
107
114
108
115
/// Gets a [`Future`] that will get woken up with any waiters
109
116
pub ( crate ) fn get_future ( & self ) -> Future {
110
- let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
111
- if lock. 0 {
112
- Future {
113
- state : Arc :: new ( Mutex :: new ( FutureState {
114
- callbacks : Vec :: new ( ) ,
115
- complete : true ,
116
- } ) )
117
- }
118
- } else if let Some ( existing_state) = & lock. 1 {
117
+ let mut lock = self . propagate_future_state_to_notify_flag ( ) ;
118
+ if let Some ( existing_state) = & lock. 1 {
119
119
Future { state : Arc :: clone ( & existing_state) }
120
120
} else {
121
121
let state = Arc :: new ( Mutex :: new ( FutureState {
122
122
callbacks : Vec :: new ( ) ,
123
- complete : false ,
123
+ complete : lock. 0 ,
124
+ callbacks_made : false ,
124
125
} ) ) ;
125
126
lock. 1 = Some ( Arc :: clone ( & state) ) ;
126
127
Future { state }
@@ -151,19 +152,21 @@ impl<F: Fn() + Send> FutureCallback for F {
151
152
}
152
153
153
154
pub ( crate ) struct FutureState {
154
- callbacks : Vec < Box < dyn FutureCallback > > ,
155
+ // When we're tracking whether a callback counts as having woken the user's code, we check the
156
+ // first bool - set to false if we're just calling a Waker, and true if we're calling an actual
157
+ // user-provided function.
158
+ callbacks : Vec < ( bool , Box < dyn FutureCallback > ) > ,
155
159
complete : bool ,
160
+ callbacks_made : bool ,
156
161
}
157
162
158
163
impl FutureState {
159
- fn complete ( & mut self ) -> bool {
160
- let mut made_calls = false ;
161
- for callback in self . callbacks . drain ( ..) {
164
+ fn complete ( & mut self ) {
165
+ for ( counts_as_call, callback) in self . callbacks . drain ( ..) {
162
166
callback. call ( ) ;
163
- made_calls = true ;
167
+ self . callbacks_made |= counts_as_call ;
164
168
}
165
169
self . complete = true ;
166
- made_calls
167
170
}
168
171
}
169
172
@@ -180,10 +183,11 @@ impl Future {
180
183
pub fn register_callback ( & self , callback : Box < dyn FutureCallback > ) {
181
184
let mut state = self . state . lock ( ) . unwrap ( ) ;
182
185
if state. complete {
186
+ state. callbacks_made = true ;
183
187
mem:: drop ( state) ;
184
188
callback. call ( ) ;
185
189
} else {
186
- state. callbacks . push ( callback) ;
190
+ state. callbacks . push ( ( true , callback) ) ;
187
191
}
188
192
}
189
193
@@ -198,12 +202,10 @@ impl Future {
198
202
}
199
203
}
200
204
201
- mod std_future {
202
- use core:: task:: Waker ;
203
- pub struct StdWaker ( pub Waker ) ;
204
- impl super :: FutureCallback for StdWaker {
205
- fn call ( & self ) { self . 0 . wake_by_ref ( ) }
206
- }
205
+ use core:: task:: Waker ;
206
+ struct StdWaker ( pub Waker ) ;
207
+ impl FutureCallback for StdWaker {
208
+ fn call ( & self ) { self . 0 . wake_by_ref ( ) }
207
209
}
208
210
209
211
/// (C-not exported) as Rust Futures aren't usable in language bindings.
@@ -213,10 +215,11 @@ impl<'a> StdFuture for Future {
213
215
fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
214
216
let mut state = self . state . lock ( ) . unwrap ( ) ;
215
217
if state. complete {
218
+ state. callbacks_made = true ;
216
219
Poll :: Ready ( ( ) )
217
220
} else {
218
221
let waker = cx. waker ( ) . clone ( ) ;
219
- state. callbacks . push ( Box :: new ( std_future :: StdWaker ( waker) ) ) ;
222
+ state. callbacks . push ( ( false , Box :: new ( StdWaker ( waker) ) ) ) ;
220
223
Poll :: Pending
221
224
}
222
225
}
@@ -285,6 +288,28 @@ mod tests {
285
288
assert ! ( !callback. load( Ordering :: SeqCst ) ) ;
286
289
}
287
290
291
+ #[ test]
292
+ fn new_future_wipes_notify_bit ( ) {
293
+ // Previously, if we were only using the `Future` interface to learn when a `Notifier` has
294
+ // been notified, we'd never mark the notifier as not-awaiting-notify if a `Future` is
295
+ // fetched after the notify bit has been set.
296
+ let notifier = Notifier :: new ( ) ;
297
+ notifier. notify ( ) ;
298
+
299
+ let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
300
+ let callback_ref = Arc :: clone ( & callback) ;
301
+ notifier. get_future ( ) . register_callback ( Box :: new ( move || assert ! ( !callback_ref. fetch_or( true , Ordering :: SeqCst ) ) ) ) ;
302
+ assert ! ( callback. load( Ordering :: SeqCst ) ) ;
303
+
304
+ let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
305
+ let callback_ref = Arc :: clone ( & callback) ;
306
+ notifier. get_future ( ) . register_callback ( Box :: new ( move || assert ! ( !callback_ref. fetch_or( true , Ordering :: SeqCst ) ) ) ) ;
307
+ assert ! ( !callback. load( Ordering :: SeqCst ) ) ;
308
+
309
+ notifier. notify ( ) ;
310
+ assert ! ( callback. load( Ordering :: SeqCst ) ) ;
311
+ }
312
+
288
313
#[ cfg( feature = "std" ) ]
289
314
#[ test]
290
315
fn test_wait_timeout ( ) {
@@ -336,6 +361,7 @@ mod tests {
336
361
state : Arc :: new ( Mutex :: new ( FutureState {
337
362
callbacks : Vec :: new ( ) ,
338
363
complete : false ,
364
+ callbacks_made : false ,
339
365
} ) )
340
366
} ;
341
367
let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
@@ -354,6 +380,7 @@ mod tests {
354
380
state : Arc :: new ( Mutex :: new ( FutureState {
355
381
callbacks : Vec :: new ( ) ,
356
382
complete : false ,
383
+ callbacks_made : false ,
357
384
} ) )
358
385
} ;
359
386
future. state . lock ( ) . unwrap ( ) . complete ( ) ;
@@ -391,6 +418,7 @@ mod tests {
391
418
state : Arc :: new ( Mutex :: new ( FutureState {
392
419
callbacks : Vec :: new ( ) ,
393
420
complete : false ,
421
+ callbacks_made : false ,
394
422
} ) )
395
423
} ;
396
424
let mut second_future = Future { state : Arc :: clone ( & future. state ) } ;
@@ -409,4 +437,36 @@ mod tests {
409
437
assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Ready ( ( ) ) ) ;
410
438
assert_eq ! ( Pin :: new( & mut second_future) . poll( & mut Context :: from_waker( & second_waker) ) , Poll :: Ready ( ( ) ) ) ;
411
439
}
440
+
441
+ #[ test]
442
+ fn test_dropped_future_doesnt_count ( ) {
443
+ // Tests that if a Future gets drop'd before it is poll()ed `Ready` it doesn't count as
444
+ // having been woken, leaving the notify-required flag set.
445
+ let notifier = Notifier :: new ( ) ;
446
+ notifier. notify ( ) ;
447
+
448
+ // If we get a future and don't touch it we're definitely still notify-required.
449
+ notifier. get_future ( ) ;
450
+ assert ! ( notifier. wait_timeout( Duration :: from_millis( 1 ) ) ) ;
451
+ assert ! ( !notifier. wait_timeout( Duration :: from_millis( 1 ) ) ) ;
452
+
453
+ // Even if we poll'd once but didn't observe a `Ready`, we should be notify-required.
454
+ let mut future = notifier. get_future ( ) ;
455
+ let ( woken, waker) = create_waker ( ) ;
456
+ assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Pending ) ;
457
+
458
+ notifier. notify ( ) ;
459
+ assert ! ( woken. load( Ordering :: SeqCst ) ) ;
460
+ assert ! ( notifier. wait_timeout( Duration :: from_millis( 1 ) ) ) ;
461
+
462
+ // However, once we do poll `Ready` it should wipe the notify-required flag.
463
+ let mut future = notifier. get_future ( ) ;
464
+ let ( woken, waker) = create_waker ( ) ;
465
+ assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Pending ) ;
466
+
467
+ notifier. notify ( ) ;
468
+ assert ! ( woken. load( Ordering :: SeqCst ) ) ;
469
+ assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Ready ( ( ) ) ) ;
470
+ assert ! ( !notifier. wait_timeout( Duration :: from_millis( 1 ) ) ) ;
471
+ }
412
472
}
0 commit comments