Skip to content

Commit b4edfab

Browse files
committed
make a safer abstraction for the main thread executor
1 parent 7383257 commit b4edfab

File tree

5 files changed

+96
-51
lines changed

5 files changed

+96
-51
lines changed

crates/bevy_ecs/src/schedule/executor_parallel.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ impl ParallelExecutor {
236236
if system_data.is_send {
237237
scope.spawn(task);
238238
} else {
239-
scope.spawn_on_scope(task);
239+
scope.spawn_on_main(task);
240240
}
241241

242242
#[cfg(test)]
@@ -271,7 +271,7 @@ impl ParallelExecutor {
271271
if system_data.is_send {
272272
scope.spawn(task);
273273
} else {
274-
scope.spawn_on_scope(task);
274+
scope.spawn_on_main(task);
275275
}
276276
}
277277
}

crates/bevy_tasks/src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ pub use single_threaded_task_pool::{Scope, TaskPool, TaskPoolBuilder};
2020
mod usages;
2121
#[cfg(not(target_arch = "wasm32"))]
2222
pub use usages::tick_global_task_pools_on_main_thread;
23-
pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, MainThreadExecutor};
23+
pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};
24+
25+
mod main_thread_executor;
26+
pub use main_thread_executor::MainThreadExecutor;
2427

2528
mod iter;
2629
pub use iter::ParallelIterator;
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use std::{marker::PhantomData, sync::Arc};
2+
3+
use async_executor::{Executor, Task};
4+
use futures_lite::Future;
5+
use is_main_thread::is_main_thread;
6+
use once_cell::sync::OnceCell;
7+
8+
static MAIN_THREAD_EXECUTOR: OnceCell<MainThreadExecutor> = OnceCell::new();
9+
10+
/// Use to access the global main thread executor. Be aware that the main thread executor
11+
/// only makes progress when it is ticked. This normally happens in `[TaskPool::scope]`.
12+
#[derive(Debug)]
13+
pub struct MainThreadExecutor(Arc<Executor<'static>>);
14+
15+
impl MainThreadExecutor {
16+
/// Initializes the global `[MainThreadExecutor]` instance.
17+
pub fn init() -> &'static Self {
18+
MAIN_THREAD_EXECUTOR.get_or_init(|| Self(Arc::new(Executor::new())))
19+
}
20+
21+
/// Gets the global [`MainThreadExecutor`] instance.
22+
///
23+
/// # Panics
24+
/// Panics if no executor has been initialized yet.
25+
pub fn get() -> &'static Self {
26+
MAIN_THREAD_EXECUTOR.get().expect(
27+
"A MainThreadExecutor has not been initialize yet. Please call \
28+
MainThreadExecutor::init beforehand",
29+
)
30+
}
31+
32+
/// Gets the `[MainThreadSpawner]` for the global main thread executor.
33+
/// Use this to spawn tasks on the main thread.
34+
pub fn spawner(&self) -> MainThreadSpawner<'static> {
35+
MainThreadSpawner(self.0.clone())
36+
}
37+
38+
/// Gets the `[MainThreadTicker]` for the global main thread executor.
39+
/// Use this to tick the main thread executor.
40+
/// Returns None if called on not the main thread.
41+
pub fn ticker(&self) -> Option<MainThreadTicker> {
42+
if let Some(is_main) = is_main_thread() {
43+
if is_main {
44+
return Some(MainThreadTicker {
45+
executor: self.0.clone(),
46+
_marker: PhantomData::default(),
47+
});
48+
}
49+
}
50+
None
51+
}
52+
}
53+
54+
#[derive(Debug)]
55+
pub struct MainThreadSpawner<'a>(Arc<Executor<'a>>);
56+
impl<'a> MainThreadSpawner<'a> {
57+
/// Spawn a task on the main thread
58+
pub fn spawn<T: Send + 'a>(&self, future: impl Future<Output = T> + Send + 'a) -> Task<T> {
59+
self.0.spawn(future)
60+
}
61+
}
62+
63+
#[derive(Debug)]
64+
pub struct MainThreadTicker {
65+
executor: Arc<Executor<'static>>,
66+
// make type not send or sync
67+
_marker: PhantomData<*const ()>,
68+
}
69+
impl MainThreadTicker {
70+
/// Tick the main thread executor.
71+
/// This needs to be called manually on the main thread if a `[TaskPool::scope]` is not active
72+
pub fn tick<'a>(&'a self) -> impl Future<Output = ()> + 'a {
73+
self.executor.tick()
74+
}
75+
}

crates/bevy_tasks/src/task_pool.rs

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ use std::{
88

99
use concurrent_queue::ConcurrentQueue;
1010
use futures_lite::{future, FutureExt};
11-
use is_main_thread::is_main_thread;
1211

13-
use crate::MainThreadExecutor;
1412
use crate::Task;
13+
use crate::{main_thread_executor::MainThreadSpawner, MainThreadExecutor};
1514

1615
/// Used to create a [`TaskPool`]
1716
#[derive(Debug, Default, Clone)]
@@ -246,16 +245,16 @@ impl TaskPool {
246245
// transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
247246
let executor: &async_executor::Executor = &self.executor;
248247
let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) };
249-
let task_scope_executor = MainThreadExecutor::init();
250-
let task_scope_executor: &'env async_executor::Executor =
251-
unsafe { mem::transmute(task_scope_executor) };
248+
let main_thread_spawner = MainThreadExecutor::init().spawner();
249+
let main_thread_spawner: MainThreadSpawner<'env> =
250+
unsafe { mem::transmute(main_thread_spawner) };
252251
let spawned: ConcurrentQueue<async_executor::Task<T>> = ConcurrentQueue::unbounded();
253252
let spawned_ref: &'env ConcurrentQueue<async_executor::Task<T>> =
254253
unsafe { mem::transmute(&spawned) };
255254

256255
let scope = Scope {
257256
executor,
258-
task_scope_executor,
257+
main_thread_spawner,
259258
spawned: spawned_ref,
260259
scope: PhantomData,
261260
env: PhantomData,
@@ -278,20 +277,10 @@ impl TaskPool {
278277
results
279278
};
280279

281-
let is_main = if let Some(is_main) = is_main_thread() {
282-
is_main
283-
} else {
284-
false
285-
};
286-
287-
if is_main {
280+
if let Some(main_thread_ticker) = MainThreadExecutor::get().ticker() {
288281
let tick_forever = async move {
289282
loop {
290-
if let Some(is_main) = is_main_thread() {
291-
if is_main {
292-
task_scope_executor.tick().await;
293-
}
294-
}
283+
main_thread_ticker.tick().await;
295284
}
296285
};
297286

@@ -372,7 +361,7 @@ impl Drop for TaskPool {
372361
#[derive(Debug)]
373362
pub struct Scope<'scope, 'env: 'scope, T> {
374363
executor: &'scope async_executor::Executor<'scope>,
375-
task_scope_executor: &'scope async_executor::Executor<'scope>,
364+
main_thread_spawner: MainThreadSpawner<'scope>,
376365
spawned: &'scope ConcurrentQueue<async_executor::Task<T>>,
377366
// make `Scope` invariant over 'scope and 'env
378367
scope: PhantomData<&'scope mut &'scope ()>,
@@ -401,8 +390,10 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
401390
/// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread.
402391
///
403392
/// For more information, see [`TaskPool::scope`].
404-
pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
405-
let task = self.task_scope_executor.spawn(f);
393+
pub fn spawn_on_main<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
394+
let main_thread_spawner: &MainThreadSpawner<'scope> =
395+
unsafe { mem::transmute(&self.main_thread_spawner) };
396+
let task = main_thread_spawner.spawn(f);
406397
// ConcurrentQueue only errors when closed or full, but we never
407398
// close and use an unbouded queue, so it is safe to unwrap
408399
self.spawned.push(task).unwrap();
@@ -473,7 +464,7 @@ mod tests {
473464
});
474465
} else {
475466
let count_clone = local_count.clone();
476-
scope.spawn_on_scope(async move {
467+
scope.spawn_on_main(async move {
477468
if *foo != 42 {
478469
panic!("not 42!?!?")
479470
} else {
@@ -514,7 +505,7 @@ mod tests {
514505
});
515506
let spawner = std::thread::current().id();
516507
let inner_count_clone = count_clone.clone();
517-
scope.spawn_on_scope(async move {
508+
scope.spawn_on_main(async move {
518509
inner_count_clone.fetch_add(1, Ordering::Release);
519510
if std::thread::current().id() != spawner {
520511
// NOTE: This check is using an atomic rather than simply panicing the
@@ -589,7 +580,7 @@ mod tests {
589580
inner_count_clone.fetch_add(1, Ordering::Release);
590581

591582
// spawning on the scope from another thread runs the futures on the scope's thread
592-
scope.spawn_on_scope(async move {
583+
scope.spawn_on_main(async move {
593584
inner_count_clone.fetch_add(1, Ordering::Release);
594585
if std::thread::current().id() != spawner {
595586
// NOTE: This check is using an atomic rather than simply panicing the

crates/bevy_tasks/src/usages.rs

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ use std::ops::Deref;
1717
static COMPUTE_TASK_POOL: OnceCell<ComputeTaskPool> = OnceCell::new();
1818
static ASYNC_COMPUTE_TASK_POOL: OnceCell<AsyncComputeTaskPool> = OnceCell::new();
1919
static IO_TASK_POOL: OnceCell<IoTaskPool> = OnceCell::new();
20-
static MAIN_THREAD_EXECUTOR: OnceCell<MainThreadExecutor> = OnceCell::new();
2120

2221
/// A newtype for a task pool for CPU-intensive work that must be completed to deliver the next
2322
/// frame
@@ -111,29 +110,6 @@ impl Deref for IoTaskPool {
111110
}
112111
}
113112

114-
pub struct MainThreadExecutor(async_executor::Executor<'static>);
115-
116-
impl MainThreadExecutor {
117-
pub fn init() -> &'static Self {
118-
MAIN_THREAD_EXECUTOR.get_or_init(|| Self(async_executor::Executor::new()))
119-
}
120-
121-
pub fn get() -> &'static Self {
122-
MAIN_THREAD_EXECUTOR.get().expect(
123-
"A MainThreadExecutor has not been initialize yet. Please call \
124-
MainThreadExecutor::init beforehand",
125-
)
126-
}
127-
}
128-
129-
impl Deref for MainThreadExecutor {
130-
type Target = async_executor::Executor<'static>;
131-
132-
fn deref(&self) -> &Self::Target {
133-
&self.0
134-
}
135-
}
136-
137113
/// Used by `bevy_core` to tick the global tasks pools on the main thread.
138114
/// This will run a maximum of 100 local tasks per executor per call to this function.
139115
#[cfg(not(target_arch = "wasm32"))]

0 commit comments

Comments
 (0)