Skip to content

Commit 8eb8ad5

Browse files
committed
await tasks to cancel (#6696)
# Objective - Fixes #6603 ## Solution - `Task`s will cancel when dropped, but wait until they return Pending before they actually get canceled. That means that if a task panics, it's possible for that error to get propagated to the scope and the scope gets dropped, while scoped tasks in other threads are still running. This is a big problem since scoped task can hold life-timed values that are dropped as the scope is dropped leading to UB. --- ## Changelog - changed `Scope` to use `FallibleTask` and await the cancellation of all remaining tasks when it's dropped.
1 parent 3433a7b commit 8eb8ad5

File tree

1 file changed

+24
-26
lines changed

1 file changed

+24
-26
lines changed

crates/bevy_tasks/src/task_pool.rs

+24-26
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ use std::{
22
future::Future,
33
marker::PhantomData,
44
mem,
5-
pin::Pin,
65
sync::Arc,
76
thread::{self, JoinHandle},
87
};
98

9+
use async_task::FallibleTask;
1010
use concurrent_queue::ConcurrentQueue;
1111
use futures_lite::{future, pin, FutureExt};
1212

@@ -248,8 +248,8 @@ impl TaskPool {
248248
let task_scope_executor = &async_executor::Executor::default();
249249
let task_scope_executor: &'env async_executor::Executor =
250250
unsafe { mem::transmute(task_scope_executor) };
251-
let spawned: ConcurrentQueue<async_executor::Task<T>> = ConcurrentQueue::unbounded();
252-
let spawned_ref: &'env ConcurrentQueue<async_executor::Task<T>> =
251+
let spawned: ConcurrentQueue<FallibleTask<T>> = ConcurrentQueue::unbounded();
252+
let spawned_ref: &'env ConcurrentQueue<FallibleTask<T>> =
253253
unsafe { mem::transmute(&spawned) };
254254

255255
let scope = Scope {
@@ -267,10 +267,10 @@ impl TaskPool {
267267
if spawned.is_empty() {
268268
Vec::new()
269269
} else {
270-
let get_results = async move {
271-
let mut results = Vec::with_capacity(spawned.len());
272-
while let Ok(task) = spawned.pop() {
273-
results.push(task.await);
270+
let get_results = async {
271+
let mut results = Vec::with_capacity(spawned_ref.len());
272+
while let Ok(task) = spawned_ref.pop() {
273+
results.push(task.await.unwrap());
274274
}
275275

276276
results
@@ -279,23 +279,8 @@ impl TaskPool {
279279
// Pin the futures on the stack.
280280
pin!(get_results);
281281

282-
// SAFETY: This function blocks until all futures complete, so we do not read/write
283-
// the data from futures outside of the 'scope lifetime. However,
284-
// rust has no way of knowing this so we must convert to 'static
285-
// here to appease the compiler as it is unable to validate safety.
286-
let get_results: Pin<&mut (dyn Future<Output = Vec<T>> + 'static + Send)> = get_results;
287-
let get_results: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static + Send)> =
288-
unsafe { mem::transmute(get_results) };
289-
290-
// The thread that calls scope() will participate in driving tasks in the pool
291-
// forward until the tasks that are spawned by this scope() call
292-
// complete. (If the caller of scope() happens to be a thread in
293-
// this thread pool, and we only have one thread in the pool, then
294-
// simply calling future::block_on(spawned) would deadlock.)
295-
let mut spawned = task_scope_executor.spawn(get_results);
296-
297282
loop {
298-
if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
283+
if let Some(result) = future::block_on(future::poll_once(&mut get_results)) {
299284
break result;
300285
};
301286

@@ -378,7 +363,7 @@ impl Drop for TaskPool {
378363
pub struct Scope<'scope, 'env: 'scope, T> {
379364
executor: &'scope async_executor::Executor<'scope>,
380365
task_scope_executor: &'scope async_executor::Executor<'scope>,
381-
spawned: &'scope ConcurrentQueue<async_executor::Task<T>>,
366+
spawned: &'scope ConcurrentQueue<FallibleTask<T>>,
382367
// make `Scope` invariant over 'scope and 'env
383368
scope: PhantomData<&'scope mut &'scope ()>,
384369
env: PhantomData<&'env mut &'env ()>,
@@ -394,7 +379,7 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
394379
///
395380
/// For more information, see [`TaskPool::scope`].
396381
pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
397-
let task = self.executor.spawn(f);
382+
let task = self.executor.spawn(f).fallible();
398383
// ConcurrentQueue only errors when closed or full, but we never
399384
// close and use an unbouded queue, so it is safe to unwrap
400385
self.spawned.push(task).unwrap();
@@ -407,13 +392,26 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
407392
///
408393
/// For more information, see [`TaskPool::scope`].
409394
pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
410-
let task = self.task_scope_executor.spawn(f);
395+
let task = self.task_scope_executor.spawn(f).fallible();
411396
// ConcurrentQueue only errors when closed or full, but we never
412397
// close and use an unbouded queue, so it is safe to unwrap
413398
self.spawned.push(task).unwrap();
414399
}
415400
}
416401

402+
impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
403+
where
404+
T: 'scope,
405+
{
406+
fn drop(&mut self) {
407+
future::block_on(async {
408+
while let Ok(task) = self.spawned.pop() {
409+
task.cancel().await;
410+
}
411+
});
412+
}
413+
}
414+
417415
#[cfg(test)]
418416
#[allow(clippy::disallowed_types)]
419417
mod tests {

0 commit comments

Comments
 (0)