Skip to content

Commit 14fca34

Browse files
authored
task: add LocalSet::enter (#4736) (#4765)
1 parent 8e20cfb commit 14fca34

File tree

3 files changed

+114
-30
lines changed

3 files changed

+114
-30
lines changed

tokio/src/task/local.rs

Lines changed: 98 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::fmt;
1010
use std::future::Future;
1111
use std::marker::PhantomData;
1212
use std::pin::Pin;
13+
use std::rc::Rc;
1314
use std::task::Poll;
1415

1516
use pin_project_lite::pin_project;
@@ -215,7 +216,7 @@ cfg_rt! {
215216
tick: Cell<u8>,
216217

217218
/// State available from thread-local.
218-
context: Context,
219+
context: Rc<Context>,
219220

220221
/// This type should not be Send.
221222
_not_send: PhantomData<*const ()>,
@@ -260,7 +261,7 @@ pin_project! {
260261
}
261262
}
262263

263-
scoped_thread_local!(static CURRENT: Context);
264+
thread_local!(static CURRENT: Cell<Option<Rc<Context>>> = Cell::new(None));
264265

265266
cfg_rt! {
266267
/// Spawns a `!Send` future on the local task set.
@@ -310,10 +311,12 @@ cfg_rt! {
310311
F::Output: 'static
311312
{
312313
CURRENT.with(|maybe_cx| {
313-
let cx = maybe_cx
314-
.expect("`spawn_local` called from outside of a `task::LocalSet`");
314+
let ctx = clone_rc(maybe_cx);
315+
match ctx {
316+
None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
317+
Some(cx) => cx.spawn(future, name)
318+
}
315319

316-
cx.spawn(future, name)
317320
})
318321
}
319322
}
@@ -327,12 +330,29 @@ const MAX_TASKS_PER_TICK: usize = 61;
327330
/// How often it check the remote queue first.
328331
const REMOTE_FIRST_INTERVAL: u8 = 31;
329332

333+
/// Context guard for LocalSet
334+
pub struct LocalEnterGuard(Option<Rc<Context>>);
335+
336+
impl Drop for LocalEnterGuard {
337+
fn drop(&mut self) {
338+
CURRENT.with(|ctx| {
339+
ctx.replace(self.0.take());
340+
})
341+
}
342+
}
343+
344+
impl fmt::Debug for LocalEnterGuard {
345+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346+
f.debug_struct("LocalEnterGuard").finish()
347+
}
348+
}
349+
330350
impl LocalSet {
331351
/// Returns a new local task set.
332352
pub fn new() -> LocalSet {
333353
LocalSet {
334354
tick: Cell::new(0),
335-
context: Context {
355+
context: Rc::new(Context {
336356
owned: LocalOwnedTasks::new(),
337357
queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
338358
shared: Arc::new(Shared {
@@ -342,11 +362,24 @@ impl LocalSet {
342362
unhandled_panic: crate::runtime::UnhandledPanic::Ignore,
343363
}),
344364
unhandled_panic: Cell::new(false),
345-
},
365+
}),
346366
_not_send: PhantomData,
347367
}
348368
}
349369

370+
/// Enters the context of this `LocalSet`.
371+
///
372+
/// The [`spawn_local`] method will spawn tasks on the `LocalSet` whose
373+
/// context you are inside.
374+
///
375+
/// [`spawn_local`]: fn@crate::task::spawn_local
376+
pub fn enter(&self) -> LocalEnterGuard {
377+
CURRENT.with(|ctx| {
378+
let old = ctx.replace(Some(self.context.clone()));
379+
LocalEnterGuard(old)
380+
})
381+
}
382+
350383
/// Spawns a `!Send` task onto the local task set.
351384
///
352385
/// This task is guaranteed to be run on the current thread.
@@ -579,7 +612,25 @@ impl LocalSet {
579612
}
580613

581614
fn with<T>(&self, f: impl FnOnce() -> T) -> T {
582-
CURRENT.set(&self.context, f)
615+
CURRENT.with(|ctx| {
616+
struct Reset<'a> {
617+
ctx_ref: &'a Cell<Option<Rc<Context>>>,
618+
val: Option<Rc<Context>>,
619+
}
620+
impl<'a> Drop for Reset<'a> {
621+
fn drop(&mut self) {
622+
self.ctx_ref.replace(self.val.take());
623+
}
624+
}
625+
let old = ctx.replace(Some(self.context.clone()));
626+
627+
let _reset = Reset {
628+
ctx_ref: ctx,
629+
val: old,
630+
};
631+
632+
f()
633+
})
583634
}
584635
}
585636

@@ -645,8 +696,9 @@ cfg_unstable! {
645696
/// [`JoinHandle`]: struct@crate::task::JoinHandle
646697
pub fn unhandled_panic(&mut self, behavior: crate::runtime::UnhandledPanic) -> &mut Self {
647698
// TODO: This should be set as a builder
648-
Arc::get_mut(&mut self.context.shared)
649-
.expect("TODO: we shouldn't panic")
699+
Rc::get_mut(&mut self.context)
700+
.and_then(|ctx| Arc::get_mut(&mut ctx.shared))
701+
.expect("Unhandled Panic behavior modified after starting LocalSet")
650702
.unhandled_panic = behavior;
651703
self
652704
}
@@ -769,23 +821,33 @@ impl<T: Future> Future for RunUntil<'_, T> {
769821
}
770822
}
771823

824+
fn clone_rc<T>(rc: &Cell<Option<Rc<T>>>) -> Option<Rc<T>> {
825+
let value = rc.take();
826+
let cloned = value.clone();
827+
rc.set(value);
828+
cloned
829+
}
830+
772831
impl Shared {
773832
/// Schedule the provided task on the scheduler.
774833
fn schedule(&self, task: task::Notified<Arc<Self>>) {
775-
CURRENT.with(|maybe_cx| match maybe_cx {
776-
Some(cx) if cx.shared.ptr_eq(self) => {
777-
cx.queue.push_back(task);
778-
}
779-
_ => {
780-
// First check whether the queue is still there (if not, the
781-
// LocalSet is dropped). Then push to it if so, and if not,
782-
// do nothing.
783-
let mut lock = self.queue.lock();
784-
785-
if let Some(queue) = lock.as_mut() {
786-
queue.push_back(task);
787-
drop(lock);
788-
self.waker.wake();
834+
CURRENT.with(|maybe_cx| {
835+
let ctx = clone_rc(maybe_cx);
836+
match ctx {
837+
Some(cx) if cx.shared.ptr_eq(self) => {
838+
cx.queue.push_back(task);
839+
}
840+
_ => {
841+
// First check whether the queue is still there (if not, the
842+
// LocalSet is dropped). Then push to it if so, and if not,
843+
// do nothing.
844+
let mut lock = self.queue.lock();
845+
846+
if let Some(queue) = lock.as_mut() {
847+
queue.push_back(task);
848+
drop(lock);
849+
self.waker.wake();
850+
}
789851
}
790852
}
791853
});
@@ -799,9 +861,14 @@ impl Shared {
799861
impl task::Schedule for Arc<Shared> {
800862
fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
801863
CURRENT.with(|maybe_cx| {
802-
let cx = maybe_cx.expect("scheduler context missing");
803-
assert!(cx.shared.ptr_eq(self));
804-
cx.owned.remove(task)
864+
let ctx = clone_rc(maybe_cx);
865+
match ctx {
866+
None => panic!("scheduler context missing"),
867+
Some(cx) => {
868+
assert!(cx.shared.ptr_eq(self));
869+
cx.owned.remove(task)
870+
}
871+
}
805872
})
806873
}
807874

@@ -821,13 +888,15 @@ impl task::Schedule for Arc<Shared> {
821888
// This hook is only called from within the runtime, so
822889
// `CURRENT` should match with `&self`, i.e. there is no
823890
// opportunity for a nested scheduler to be called.
824-
CURRENT.with(|maybe_cx| match maybe_cx {
891+
CURRENT.with(|maybe_cx| {
892+
let ctx = clone_rc(maybe_cx);
893+
match ctx {
825894
Some(cx) if Arc::ptr_eq(self, &cx.shared) => {
826895
cx.unhandled_panic.set(true);
827896
cx.owned.close_and_shutdown_all();
828897
}
829898
_ => unreachable!("runtime core not set in CURRENT thread-local"),
830-
})
899+
}})
831900
}
832901
}
833902
}

tokio/src/task/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ cfg_rt! {
299299
}
300300

301301
mod local;
302-
pub use local::{spawn_local, LocalSet};
302+
pub use local::{spawn_local, LocalSet, LocalEnterGuard};
303303

304304
mod task_local;
305305
pub use task_local::LocalKey;

tokio/tests/task_local_set.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,21 @@ async fn local_threadpool_timer() {
135135
})
136136
.await;
137137
}
138+
#[test]
139+
fn enter_guard_spawn() {
140+
let local = LocalSet::new();
141+
let _guard = local.enter();
142+
// Run the local task set.
143+
144+
let join = task::spawn_local(async { true });
145+
let rt = runtime::Builder::new_current_thread()
146+
.enable_all()
147+
.build()
148+
.unwrap();
149+
local.block_on(&rt, async move {
150+
assert!(join.await.unwrap());
151+
});
152+
}
138153

139154
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support panic recovery
140155
#[test]

0 commit comments

Comments
 (0)