@@ -2,11 +2,11 @@ use std::{
2
2
future:: Future ,
3
3
marker:: PhantomData ,
4
4
mem,
5
- pin:: Pin ,
6
5
sync:: Arc ,
7
6
thread:: { self , JoinHandle } ,
8
7
} ;
9
8
9
+ use async_task:: FallibleTask ;
10
10
use concurrent_queue:: ConcurrentQueue ;
11
11
use futures_lite:: { future, pin, FutureExt } ;
12
12
@@ -248,8 +248,8 @@ impl TaskPool {
248
248
let task_scope_executor = & async_executor:: Executor :: default ( ) ;
249
249
let task_scope_executor: & ' env async_executor:: Executor =
250
250
unsafe { mem:: transmute ( task_scope_executor) } ;
251
- let spawned: ConcurrentQueue < async_executor :: Task < T > > = ConcurrentQueue :: unbounded ( ) ;
252
- let spawned_ref: & ' env ConcurrentQueue < async_executor :: Task < T > > =
251
+ let spawned: ConcurrentQueue < FallibleTask < T > > = ConcurrentQueue :: unbounded ( ) ;
252
+ let spawned_ref: & ' env ConcurrentQueue < FallibleTask < T > > =
253
253
unsafe { mem:: transmute ( & spawned) } ;
254
254
255
255
let scope = Scope {
@@ -267,10 +267,10 @@ impl TaskPool {
267
267
if spawned. is_empty ( ) {
268
268
Vec :: new ( )
269
269
} else {
270
- let get_results = async move {
271
- let mut results = Vec :: with_capacity ( spawned . len ( ) ) ;
272
- while let Ok ( task) = spawned . pop ( ) {
273
- results. push ( task. await ) ;
270
+ let get_results = async {
271
+ let mut results = Vec :: with_capacity ( spawned_ref . len ( ) ) ;
272
+ while let Ok ( task) = spawned_ref . pop ( ) {
273
+ results. push ( task. await . unwrap ( ) ) ;
274
274
}
275
275
276
276
results
@@ -279,23 +279,8 @@ impl TaskPool {
279
279
// Pin the futures on the stack.
280
280
pin ! ( get_results) ;
281
281
282
- // SAFETY: This function blocks until all futures complete, so we do not read/write
283
- // the data from futures outside of the 'scope lifetime. However,
284
- // rust has no way of knowing this so we must convert to 'static
285
- // here to appease the compiler as it is unable to validate safety.
286
- let get_results: Pin < & mut ( dyn Future < Output = Vec < T > > + ' static + Send ) > = get_results;
287
- let get_results: Pin < & ' static mut ( dyn Future < Output = Vec < T > > + ' static + Send ) > =
288
- unsafe { mem:: transmute ( get_results) } ;
289
-
290
- // The thread that calls scope() will participate in driving tasks in the pool
291
- // forward until the tasks that are spawned by this scope() call
292
- // complete. (If the caller of scope() happens to be a thread in
293
- // this thread pool, and we only have one thread in the pool, then
294
- // simply calling future::block_on(spawned) would deadlock.)
295
- let mut spawned = task_scope_executor. spawn ( get_results) ;
296
-
297
282
loop {
298
- if let Some ( result) = future:: block_on ( future:: poll_once ( & mut spawned ) ) {
283
+ if let Some ( result) = future:: block_on ( future:: poll_once ( & mut get_results ) ) {
299
284
break result;
300
285
} ;
301
286
@@ -378,7 +363,7 @@ impl Drop for TaskPool {
378
363
pub struct Scope < ' scope , ' env : ' scope , T > {
379
364
executor : & ' scope async_executor:: Executor < ' scope > ,
380
365
task_scope_executor : & ' scope async_executor:: Executor < ' scope > ,
381
- spawned : & ' scope ConcurrentQueue < async_executor :: Task < T > > ,
366
+ spawned : & ' scope ConcurrentQueue < FallibleTask < T > > ,
382
367
// make `Scope` invariant over 'scope and 'env
383
368
scope : PhantomData < & ' scope mut & ' scope ( ) > ,
384
369
env : PhantomData < & ' env mut & ' env ( ) > ,
@@ -394,7 +379,7 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
394
379
///
395
380
/// For more information, see [`TaskPool::scope`].
396
381
pub fn spawn < Fut : Future < Output = T > + ' scope + Send > ( & self , f : Fut ) {
397
- let task = self . executor . spawn ( f) ;
382
+ let task = self . executor . spawn ( f) . fallible ( ) ;
398
383
// ConcurrentQueue only errors when closed or full, but we never
399
384
// close and use an unbouded queue, so it is safe to unwrap
400
385
self . spawned . push ( task) . unwrap ( ) ;
@@ -407,13 +392,26 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
407
392
///
408
393
/// For more information, see [`TaskPool::scope`].
409
394
pub fn spawn_on_scope < Fut : Future < Output = T > + ' scope + Send > ( & self , f : Fut ) {
410
- let task = self . task_scope_executor . spawn ( f) ;
395
+ let task = self . task_scope_executor . spawn ( f) . fallible ( ) ;
411
396
// ConcurrentQueue only errors when closed or full, but we never
412
397
// close and use an unbouded queue, so it is safe to unwrap
413
398
self . spawned . push ( task) . unwrap ( ) ;
414
399
}
415
400
}
416
401
402
+ impl < ' scope , ' env , T > Drop for Scope < ' scope , ' env , T >
403
+ where
404
+ T : ' scope ,
405
+ {
406
+ fn drop ( & mut self ) {
407
+ future:: block_on ( async {
408
+ while let Ok ( task) = self . spawned . pop ( ) {
409
+ task. cancel ( ) . await ;
410
+ }
411
+ } ) ;
412
+ }
413
+ }
414
+
417
415
#[ cfg( test) ]
418
416
#[ allow( clippy:: disallowed_types) ]
419
417
mod tests {
0 commit comments