@@ -10,6 +10,7 @@ use std::fmt;
10
10
use std:: future:: Future ;
11
11
use std:: marker:: PhantomData ;
12
12
use std:: pin:: Pin ;
13
+ use std:: rc:: Rc ;
13
14
use std:: task:: Poll ;
14
15
15
16
use pin_project_lite:: pin_project;
@@ -215,7 +216,7 @@ cfg_rt! {
215
216
tick: Cell <u8 >,
216
217
217
218
/// State available from thread-local.
218
- context: Context ,
219
+ context: Rc < Context > ,
219
220
220
221
/// This type should not be Send.
221
222
_not_send: PhantomData <* const ( ) >,
@@ -260,7 +261,7 @@ pin_project! {
260
261
}
261
262
}
262
263
263
- scoped_thread_local ! ( static CURRENT : Context ) ;
264
+ thread_local ! ( static CURRENT : Cell < Option < Rc < Context >>> = Cell :: new ( None ) ) ;
264
265
265
266
cfg_rt ! {
266
267
/// Spawns a `!Send` future on the local task set.
@@ -310,10 +311,12 @@ cfg_rt! {
310
311
F :: Output : ' static
311
312
{
312
313
CURRENT . with( |maybe_cx| {
313
- let cx = maybe_cx
314
- . expect( "`spawn_local` called from outside of a `task::LocalSet`" ) ;
314
+ let ctx = clone_rc( maybe_cx) ;
315
+ match ctx {
316
+ None => panic!( "`spawn_local` called from outside of a `task::LocalSet`" ) ,
317
+ Some ( cx) => cx. spawn( future, name)
318
+ }
315
319
316
- cx. spawn( future, name)
317
320
} )
318
321
}
319
322
}
@@ -327,12 +330,29 @@ const MAX_TASKS_PER_TICK: usize = 61;
327
330
/// How often it check the remote queue first.
328
331
const REMOTE_FIRST_INTERVAL : u8 = 31 ;
329
332
333
+ /// Context guard for LocalSet
334
+ pub struct LocalEnterGuard ( Option < Rc < Context > > ) ;
335
+
336
+ impl Drop for LocalEnterGuard {
337
+ fn drop ( & mut self ) {
338
+ CURRENT . with ( |ctx| {
339
+ ctx. replace ( self . 0 . take ( ) ) ;
340
+ } )
341
+ }
342
+ }
343
+
344
+ impl fmt:: Debug for LocalEnterGuard {
345
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
346
+ f. debug_struct ( "LocalEnterGuard" ) . finish ( )
347
+ }
348
+ }
349
+
330
350
impl LocalSet {
331
351
/// Returns a new local task set.
332
352
pub fn new ( ) -> LocalSet {
333
353
LocalSet {
334
354
tick : Cell :: new ( 0 ) ,
335
- context : Context {
355
+ context : Rc :: new ( Context {
336
356
owned : LocalOwnedTasks :: new ( ) ,
337
357
queue : VecDequeCell :: with_capacity ( INITIAL_CAPACITY ) ,
338
358
shared : Arc :: new ( Shared {
@@ -342,11 +362,24 @@ impl LocalSet {
342
362
unhandled_panic : crate :: runtime:: UnhandledPanic :: Ignore ,
343
363
} ) ,
344
364
unhandled_panic : Cell :: new ( false ) ,
345
- } ,
365
+ } ) ,
346
366
_not_send : PhantomData ,
347
367
}
348
368
}
349
369
370
+ /// Enters the context of this `LocalSet`.
371
+ ///
372
+ /// The [`spawn_local`] method will spawn tasks on the `LocalSet` whose
373
+ /// context you are inside.
374
+ ///
375
+ /// [`spawn_local`]: fn@crate::task::spawn_local
376
+ pub fn enter ( & self ) -> LocalEnterGuard {
377
+ CURRENT . with ( |ctx| {
378
+ let old = ctx. replace ( Some ( self . context . clone ( ) ) ) ;
379
+ LocalEnterGuard ( old)
380
+ } )
381
+ }
382
+
350
383
/// Spawns a `!Send` task onto the local task set.
351
384
///
352
385
/// This task is guaranteed to be run on the current thread.
@@ -579,7 +612,25 @@ impl LocalSet {
579
612
}
580
613
581
614
fn with < T > ( & self , f : impl FnOnce ( ) -> T ) -> T {
582
- CURRENT . set ( & self . context , f)
615
+ CURRENT . with ( |ctx| {
616
+ struct Reset < ' a > {
617
+ ctx_ref : & ' a Cell < Option < Rc < Context > > > ,
618
+ val : Option < Rc < Context > > ,
619
+ }
620
+ impl < ' a > Drop for Reset < ' a > {
621
+ fn drop ( & mut self ) {
622
+ self . ctx_ref . replace ( self . val . take ( ) ) ;
623
+ }
624
+ }
625
+ let old = ctx. replace ( Some ( self . context . clone ( ) ) ) ;
626
+
627
+ let _reset = Reset {
628
+ ctx_ref : ctx,
629
+ val : old,
630
+ } ;
631
+
632
+ f ( )
633
+ } )
583
634
}
584
635
}
585
636
@@ -645,8 +696,9 @@ cfg_unstable! {
645
696
/// [`JoinHandle`]: struct@crate::task::JoinHandle
646
697
pub fn unhandled_panic( & mut self , behavior: crate :: runtime:: UnhandledPanic ) -> & mut Self {
647
698
// TODO: This should be set as a builder
648
- Arc :: get_mut( & mut self . context. shared)
649
- . expect( "TODO: we shouldn't panic" )
699
+ Rc :: get_mut( & mut self . context)
700
+ . and_then( |ctx| Arc :: get_mut( & mut ctx. shared) )
701
+ . expect( "Unhandled Panic behavior modified after starting LocalSet" )
650
702
. unhandled_panic = behavior;
651
703
self
652
704
}
@@ -769,23 +821,33 @@ impl<T: Future> Future for RunUntil<'_, T> {
769
821
}
770
822
}
771
823
824
+ fn clone_rc < T > ( rc : & Cell < Option < Rc < T > > > ) -> Option < Rc < T > > {
825
+ let value = rc. take ( ) ;
826
+ let cloned = value. clone ( ) ;
827
+ rc. set ( value) ;
828
+ cloned
829
+ }
830
+
772
831
impl Shared {
773
832
/// Schedule the provided task on the scheduler.
774
833
fn schedule ( & self , task : task:: Notified < Arc < Self > > ) {
775
- CURRENT . with ( |maybe_cx| match maybe_cx {
776
- Some ( cx) if cx. shared . ptr_eq ( self ) => {
777
- cx. queue . push_back ( task) ;
778
- }
779
- _ => {
780
- // First check whether the queue is still there (if not, the
781
- // LocalSet is dropped). Then push to it if so, and if not,
782
- // do nothing.
783
- let mut lock = self . queue . lock ( ) ;
784
-
785
- if let Some ( queue) = lock. as_mut ( ) {
786
- queue. push_back ( task) ;
787
- drop ( lock) ;
788
- self . waker . wake ( ) ;
834
+ CURRENT . with ( |maybe_cx| {
835
+ let ctx = clone_rc ( maybe_cx) ;
836
+ match ctx {
837
+ Some ( cx) if cx. shared . ptr_eq ( self ) => {
838
+ cx. queue . push_back ( task) ;
839
+ }
840
+ _ => {
841
+ // First check whether the queue is still there (if not, the
842
+ // LocalSet is dropped). Then push to it if so, and if not,
843
+ // do nothing.
844
+ let mut lock = self . queue . lock ( ) ;
845
+
846
+ if let Some ( queue) = lock. as_mut ( ) {
847
+ queue. push_back ( task) ;
848
+ drop ( lock) ;
849
+ self . waker . wake ( ) ;
850
+ }
789
851
}
790
852
}
791
853
} ) ;
@@ -799,9 +861,14 @@ impl Shared {
799
861
impl task:: Schedule for Arc < Shared > {
800
862
fn release ( & self , task : & Task < Self > ) -> Option < Task < Self > > {
801
863
CURRENT . with ( |maybe_cx| {
802
- let cx = maybe_cx. expect ( "scheduler context missing" ) ;
803
- assert ! ( cx. shared. ptr_eq( self ) ) ;
804
- cx. owned . remove ( task)
864
+ let ctx = clone_rc ( maybe_cx) ;
865
+ match ctx {
866
+ None => panic ! ( "scheduler context missing" ) ,
867
+ Some ( cx) => {
868
+ assert ! ( cx. shared. ptr_eq( self ) ) ;
869
+ cx. owned . remove ( task)
870
+ }
871
+ }
805
872
} )
806
873
}
807
874
@@ -821,13 +888,15 @@ impl task::Schedule for Arc<Shared> {
821
888
// This hook is only called from within the runtime, so
822
889
// `CURRENT` should match with `&self`, i.e. there is no
823
890
// opportunity for a nested scheduler to be called.
824
- CURRENT . with( |maybe_cx| match maybe_cx {
891
+ CURRENT . with( |maybe_cx| {
892
+ let ctx = clone_rc( maybe_cx) ;
893
+ match ctx {
825
894
Some ( cx) if Arc :: ptr_eq( self , & cx. shared) => {
826
895
cx. unhandled_panic. set( true ) ;
827
896
cx. owned. close_and_shutdown_all( ) ;
828
897
}
829
898
_ => unreachable!( "runtime core not set in CURRENT thread-local" ) ,
830
- } )
899
+ } } )
831
900
}
832
901
}
833
902
}
0 commit comments