@@ -8,10 +8,9 @@ use std::{
8
8
9
9
use concurrent_queue:: ConcurrentQueue ;
10
10
use futures_lite:: { future, FutureExt } ;
11
- use is_main_thread:: is_main_thread;
12
11
13
- use crate :: MainThreadExecutor ;
14
12
use crate :: Task ;
13
+ use crate :: { main_thread_executor:: MainThreadSpawner , MainThreadExecutor } ;
15
14
16
15
/// Used to create a [`TaskPool`]
17
16
#[ derive( Debug , Default , Clone ) ]
@@ -246,16 +245,16 @@ impl TaskPool {
246
245
// transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
247
246
let executor: & async_executor:: Executor = & self . executor ;
248
247
let executor: & ' env async_executor:: Executor = unsafe { mem:: transmute ( executor) } ;
249
- let task_scope_executor = MainThreadExecutor :: init ( ) ;
250
- let task_scope_executor : & ' env async_executor :: Executor =
251
- unsafe { mem:: transmute ( task_scope_executor ) } ;
248
+ let main_thread_spawner = MainThreadExecutor :: init ( ) . spawner ( ) ;
249
+ let main_thread_spawner : MainThreadSpawner < ' env > =
250
+ unsafe { mem:: transmute ( main_thread_spawner ) } ;
252
251
let spawned: ConcurrentQueue < async_executor:: Task < T > > = ConcurrentQueue :: unbounded ( ) ;
253
252
let spawned_ref: & ' env ConcurrentQueue < async_executor:: Task < T > > =
254
253
unsafe { mem:: transmute ( & spawned) } ;
255
254
256
255
let scope = Scope {
257
256
executor,
258
- task_scope_executor ,
257
+ main_thread_spawner ,
259
258
spawned : spawned_ref,
260
259
scope : PhantomData ,
261
260
env : PhantomData ,
@@ -278,20 +277,10 @@ impl TaskPool {
278
277
results
279
278
} ;
280
279
281
- let is_main = if let Some ( is_main) = is_main_thread ( ) {
282
- is_main
283
- } else {
284
- false
285
- } ;
286
-
287
- if is_main {
280
+ if let Some ( main_thread_ticker) = MainThreadExecutor :: get ( ) . ticker ( ) {
288
281
let tick_forever = async move {
289
282
loop {
290
- if let Some ( is_main) = is_main_thread ( ) {
291
- if is_main {
292
- task_scope_executor. tick ( ) . await ;
293
- }
294
- }
283
+ main_thread_ticker. tick ( ) . await ;
295
284
}
296
285
} ;
297
286
@@ -372,7 +361,7 @@ impl Drop for TaskPool {
372
361
#[ derive( Debug ) ]
373
362
pub struct Scope < ' scope , ' env : ' scope , T > {
374
363
executor : & ' scope async_executor:: Executor < ' scope > ,
375
- task_scope_executor : & ' scope async_executor :: Executor < ' scope > ,
364
+ main_thread_spawner : MainThreadSpawner < ' scope > ,
376
365
spawned : & ' scope ConcurrentQueue < async_executor:: Task < T > > ,
377
366
// make `Scope` invariant over 'scope and 'env
378
367
scope : PhantomData < & ' scope mut & ' scope ( ) > ,
@@ -401,8 +390,10 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
401
390
/// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread.
402
391
///
403
392
/// For more information, see [`TaskPool::scope`].
404
- pub fn spawn_on_scope < Fut : Future < Output = T > + ' scope + Send > ( & self , f : Fut ) {
405
- let task = self . task_scope_executor . spawn ( f) ;
393
+ pub fn spawn_on_main < Fut : Future < Output = T > + ' scope + Send > ( & self , f : Fut ) {
394
+ let main_thread_spawner: & MainThreadSpawner < ' scope > =
395
+ unsafe { mem:: transmute ( & self . main_thread_spawner ) } ;
396
+ let task = main_thread_spawner. spawn ( f) ;
406
397
// ConcurrentQueue only errors when closed or full, but we never
407
398
// close and use an unbouded queue, so it is safe to unwrap
408
399
self . spawned . push ( task) . unwrap ( ) ;
@@ -473,7 +464,7 @@ mod tests {
473
464
} ) ;
474
465
} else {
475
466
let count_clone = local_count. clone ( ) ;
476
- scope. spawn_on_scope ( async move {
467
+ scope. spawn_on_main ( async move {
477
468
if * foo != 42 {
478
469
panic ! ( "not 42!?!?" )
479
470
} else {
@@ -514,7 +505,7 @@ mod tests {
514
505
} ) ;
515
506
let spawner = std:: thread:: current ( ) . id ( ) ;
516
507
let inner_count_clone = count_clone. clone ( ) ;
517
- scope. spawn_on_scope ( async move {
508
+ scope. spawn_on_main ( async move {
518
509
inner_count_clone. fetch_add ( 1 , Ordering :: Release ) ;
519
510
if std:: thread:: current ( ) . id ( ) != spawner {
520
511
// NOTE: This check is using an atomic rather than simply panicing the
@@ -589,7 +580,7 @@ mod tests {
589
580
inner_count_clone. fetch_add ( 1 , Ordering :: Release ) ;
590
581
591
582
// spawning on the scope from another thread runs the futures on the scope's thread
592
- scope. spawn_on_scope ( async move {
583
+ scope. spawn_on_main ( async move {
593
584
inner_count_clone. fetch_add ( 1 , Ordering :: Release ) ;
594
585
if std:: thread:: current ( ) . id ( ) != spawner {
595
586
// NOTE: This check is using an atomic rather than simply panicing the
0 commit comments