Skip to content

Consolidate together Bevy's TaskPools [adopted] #18163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 31 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6bb352c
Consolidate together Bevy's TaskPools
james7132 Feb 24, 2024
803e74f
Formatting
james7132 Feb 24, 2024
29e5fab
Backticks
james7132 Feb 24, 2024
05007b0
Apply suggestions from code review
james7132 Feb 24, 2024
16c9816
Add a spawn_blocking_async
james7132 Feb 25, 2024
332c98b
Add configuration for the number of blocking threads.
james7132 Feb 25, 2024
589cab7
Fix build
james7132 Feb 25, 2024
7d85100
Apply suggestions from code review
james7132 Feb 25, 2024
ac13d61
Fix Wasm and document platform specific behavior.
james7132 Feb 26, 2024
e08a7fc
Formatting
james7132 Feb 26, 2024
b0cb7c6
Fix warning
james7132 Feb 26, 2024
25706c7
Fix toml formatting
james7132 Feb 27, 2024
f3ef65c
Provide more disambiguation for num_blocking_threads.
james7132 Feb 27, 2024
9bf32a5
Correct documentation on spawn_blocking_async
james7132 Feb 27, 2024
4896e19
Formatting
james7132 Feb 27, 2024
95b7435
Fix typos
james7132 Feb 27, 2024
274fe31
Fix more typos
james7132 Feb 27, 2024
207075e
Merge branch 'main' into task-pool-consolidation
james7132 Apr 17, 2024
9a64617
Remove reference to IO task pool.
james7132 Apr 17, 2024
1bc4ff4
Toml formatting
james7132 Apr 17, 2024
c7e7c60
Merge remote-tracking branch 'upstream/main' into task-pool-consolidation
hymm Mar 5, 2025
8161b3e
fix merge issues
hymm Mar 5, 2025
72e6719
add spawn blocking to scope
hymm Mar 5, 2025
fa9ec1d
fix task pool plugin rebase
hymm Mar 5, 2025
60e7762
fmt
hymm Mar 5, 2025
d0a19e9
temporarily comment out more code
hymm Mar 5, 2025
d3bb78e
add reason
hymm Mar 5, 2025
d8828e4
add maybesend and maybesync to single threaded
hymm Mar 5, 2025
e1e8533
format tasks cargo.toml
hymm Mar 5, 2025
78b78ec
fix doc links
hymm Mar 5, 2025
61335ae
ci
hymm Mar 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 36 additions & 173 deletions crates/bevy_app/src/task_pool_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{App, Plugin};

use alloc::string::ToString;
use bevy_platform_support::sync::Arc;
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
use bevy_tasks::{ComputeTaskPool, TaskPoolBuilder};
use core::{fmt::Debug, marker::PhantomData};
use log::trace;

Expand All @@ -12,7 +12,7 @@ use {crate::Last, bevy_ecs::prelude::NonSend};
#[cfg(not(target_arch = "wasm32"))]
use bevy_tasks::tick_global_task_pools_on_main_thread;

/// Setup of default task pools: [`AsyncComputeTaskPool`], [`ComputeTaskPool`], [`IoTaskPool`].
/// Setup of default task pools: [`ComputeTaskPool`].
#[derive(Default)]
pub struct TaskPoolPlugin {
/// Options for the [`TaskPool`](bevy_tasks::TaskPool) created at application start.
Expand Down Expand Up @@ -40,17 +40,16 @@ fn tick_global_task_pools(_main_thread_marker: Option<NonSend<NonSendMarker>>) {
tick_global_task_pools_on_main_thread();
}

/// Defines a simple way to determine how many threads to use given the number of remaining cores
/// and number of total cores
/// Helper for configuring and creating the default task pools. For end-users who want full control,
/// set up [`TaskPoolPlugin`]
#[derive(Clone)]
pub struct TaskPoolThreadAssignmentPolicy {
/// Force using at least this many threads
pub min_threads: usize,
/// Under no circumstance use more than this many threads for this pool
pub max_threads: usize,
/// Target using this percentage of total cores, clamped by `min_threads` and `max_threads`. It is
/// permitted to use 1.0 to try to use all remaining threads
pub percent: f32,
pub struct TaskPoolOptions {
/// If the number of physical cores is less than `min_total_threads`, force using
/// `min_total_threads`
pub min_total_threads: usize,
/// If the number of physical cores is greater than `max_total_threads`, force using
/// `max_total_threads`
pub max_total_threads: usize,
/// Callback that is invoked once for every created thread as it starts.
/// This configuration will be ignored under wasm platform.
pub on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
Expand All @@ -59,91 +58,25 @@ pub struct TaskPoolThreadAssignmentPolicy {
pub on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

impl Debug for TaskPoolThreadAssignmentPolicy {
impl Debug for TaskPoolOptions {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("TaskPoolThreadAssignmentPolicy")
.field("min_threads", &self.min_threads)
.field("max_threads", &self.max_threads)
.field("percent", &self.percent)
f.debug_struct("TaskPoolOptions")
.field("min_total_threads", &self.min_total_threads)
.field("max_total_threads", &self.max_total_threads)
.field("on_thread_spawn", &self.on_thread_spawn.is_some())
.field("on_thread_destroy", &self.on_thread_destroy.is_some())
.finish()
}
}

impl TaskPoolThreadAssignmentPolicy {
/// Determine the number of threads to use for this task pool
fn get_number_of_threads(&self, remaining_threads: usize, total_threads: usize) -> usize {
assert!(self.percent >= 0.0);
let proportion = total_threads as f32 * self.percent;
let mut desired = proportion as usize;

// Equivalent to round() for positive floats without libm requirement for
// no_std compatibility
if proportion - desired as f32 >= 0.5 {
desired += 1;
}

// Limit ourselves to the number of cores available
desired = desired.min(remaining_threads);

// Clamp by min_threads, max_threads. (This may result in us using more threads than are
// available, this is intended. An example case where this might happen is a device with
// <= 2 threads.
desired.clamp(self.min_threads, self.max_threads)
}
}

/// Helper for configuring and creating the default task pools. For end-users who want full control,
/// set up [`TaskPoolPlugin`]
#[derive(Clone, Debug)]
pub struct TaskPoolOptions {
/// If the number of physical cores is less than `min_total_threads`, force using
/// `min_total_threads`
pub min_total_threads: usize,
/// If the number of physical cores is greater than `max_total_threads`, force using
/// `max_total_threads`
pub max_total_threads: usize,

/// Used to determine number of IO threads to allocate
pub io: TaskPoolThreadAssignmentPolicy,
/// Used to determine number of async compute threads to allocate
pub async_compute: TaskPoolThreadAssignmentPolicy,
/// Used to determine number of compute threads to allocate
pub compute: TaskPoolThreadAssignmentPolicy,
}

impl Default for TaskPoolOptions {
fn default() -> Self {
TaskPoolOptions {
// By default, use however many cores are available on the system
min_total_threads: 1,
max_total_threads: usize::MAX,

// Use 25% of cores for IO, at least 1, no more than 4
io: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
on_thread_spawn: None,
on_thread_destroy: None,
},

// Use 25% of cores for async compute, at least 1, no more than 4
async_compute: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
on_thread_spawn: None,
on_thread_destroy: None,
},

// Use all remaining cores for compute (at least 1)
compute: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: usize::MAX,
percent: 1.0, // This 1.0 here means "whatever is left over"
on_thread_spawn: None,
on_thread_destroy: None,
},
on_thread_spawn: None,
on_thread_destroy: None,
}
}
}
Expand All @@ -164,109 +97,39 @@ impl TaskPoolOptions {
.clamp(self.min_total_threads, self.max_total_threads);
trace!("Assigning {} cores to default task pools", total_threads);

let mut remaining_threads = total_threads;

{
// Determine the number of IO threads we will use
let io_threads = self
.io
.get_number_of_threads(remaining_threads, total_threads);
ComputeTaskPool::get_or_init(|| {
#[cfg_attr(target_arch = "wasm32", expect(unused_mut))]
let mut builder = TaskPoolBuilder::default()
.num_threads(total_threads)
.thread_name("Compute Task Pool".to_string());

trace!("IO Threads: {}", io_threads);
remaining_threads = remaining_threads.saturating_sub(io_threads);

IoTaskPool::get_or_init(|| {
#[cfg_attr(target_arch = "wasm32", expect(unused_mut))]
let mut builder = TaskPoolBuilder::default()
.num_threads(io_threads)
.thread_name("IO Task Pool".to_string());

#[cfg(not(target_arch = "wasm32"))]
{
if let Some(f) = self.io.on_thread_spawn.clone() {
builder = builder.on_thread_spawn(move || f());
}
if let Some(f) = self.io.on_thread_destroy.clone() {
builder = builder.on_thread_destroy(move || f());
}
#[cfg(not(target_arch = "wasm32"))]
{
if let Some(f) = self.on_thread_spawn.clone() {
builder = builder.on_thread_spawn(move || f());
}

builder.build()
});
}

{
// Determine the number of async compute threads we will use
let async_compute_threads = self
.async_compute
.get_number_of_threads(remaining_threads, total_threads);

trace!("Async Compute Threads: {}", async_compute_threads);
remaining_threads = remaining_threads.saturating_sub(async_compute_threads);

AsyncComputeTaskPool::get_or_init(|| {
#[cfg_attr(target_arch = "wasm32", expect(unused_mut))]
let mut builder = TaskPoolBuilder::default()
.num_threads(async_compute_threads)
.thread_name("Async Compute Task Pool".to_string());

#[cfg(not(target_arch = "wasm32"))]
{
if let Some(f) = self.async_compute.on_thread_spawn.clone() {
builder = builder.on_thread_spawn(move || f());
}
if let Some(f) = self.async_compute.on_thread_destroy.clone() {
builder = builder.on_thread_destroy(move || f());
}
}

builder.build()
});
}

{
// Determine the number of compute threads we will use
// This is intentionally last so that an end user can specify 1.0 as the percent
let compute_threads = self
.compute
.get_number_of_threads(remaining_threads, total_threads);

trace!("Compute Threads: {}", compute_threads);

ComputeTaskPool::get_or_init(|| {
#[cfg_attr(target_arch = "wasm32", expect(unused_mut))]
let mut builder = TaskPoolBuilder::default()
.num_threads(compute_threads)
.thread_name("Compute Task Pool".to_string());

#[cfg(not(target_arch = "wasm32"))]
{
if let Some(f) = self.compute.on_thread_spawn.clone() {
builder = builder.on_thread_spawn(move || f());
}
if let Some(f) = self.compute.on_thread_destroy.clone() {
builder = builder.on_thread_destroy(move || f());
}
if let Some(f) = self.on_thread_destroy.clone() {
builder = builder.on_thread_destroy(move || f());
}
}

builder.build()
});
}
builder.build()
});
}
}

#[cfg(test)]
mod tests {
use super::*;
use bevy_tasks::prelude::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};
use bevy_tasks::prelude::ComputeTaskPool;

#[test]
fn runs_spawn_local_tasks() {
let mut app = App::new();
app.add_plugins(TaskPoolPlugin::default());

let (async_tx, async_rx) = crossbeam_channel::unbounded();
AsyncComputeTaskPool::get()
ComputeTaskPool::get()
.spawn_local(async move {
async_tx.send(()).unwrap();
})
Expand All @@ -280,7 +143,7 @@ mod tests {
.detach();

let (io_tx, io_rx) = crossbeam_channel::unbounded();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn_local(async move {
io_tx.send(()).unwrap();
})
Expand Down
8 changes: 4 additions & 4 deletions crates/bevy_asset/src/processor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use crate::{
use alloc::{borrow::ToOwned, boxed::Box, collections::VecDeque, sync::Arc, vec, vec::Vec};
use bevy_ecs::prelude::*;
use bevy_platform_support::collections::{HashMap, HashSet};
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use futures_io::ErrorKind;
use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt};
use parking_lot::RwLock;
Expand Down Expand Up @@ -218,7 +218,7 @@ impl AssetProcessor {
pub fn process_assets(&self) {
let start_time = std::time::Instant::now();
debug!("Processing Assets");
IoTaskPool::get().scope(|scope| {
ComputeTaskPool::get().scope(|scope| {
scope.spawn(async move {
self.initialize().await.unwrap();
for source in self.sources().iter_processed() {
Expand Down Expand Up @@ -368,7 +368,7 @@ impl AssetProcessor {
#[cfg(any(target_arch = "wasm32", not(feature = "multi_threaded")))]
error!("AddFolder event cannot be handled in single threaded mode (or Wasm) yet.");
#[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))]
IoTaskPool::get().scope(|scope| {
ComputeTaskPool::get().scope(|scope| {
scope.spawn(async move {
self.process_assets_internal(scope, source, path)
.await
Expand Down Expand Up @@ -510,7 +510,7 @@ impl AssetProcessor {
loop {
let mut check_reprocess_queue =
core::mem::take(&mut self.data.asset_infos.write().await.check_reprocess_queue);
IoTaskPool::get().scope(|scope| {
ComputeTaskPool::get().scope(|scope| {
for path in check_reprocess_queue.drain(..) {
let processor = self.clone();
let source = self.get_source(path.source()).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/bevy_asset/src/server/loaders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use async_broadcast::RecvError;
use bevy_platform_support::collections::HashMap;
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use bevy_utils::TypeIdMap;
use core::any::TypeId;
use thiserror::Error;
Expand Down Expand Up @@ -93,7 +93,7 @@ impl AssetLoaders {
match maybe_loader {
MaybeAssetLoader::Ready(_) => unreachable!(),
MaybeAssetLoader::Pending { sender, .. } => {
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let _ = sender.broadcast(loader).await;
})
Expand Down
12 changes: 6 additions & 6 deletions crates/bevy_asset/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use alloc::{
use atomicow::CowArc;
use bevy_ecs::prelude::*;
use bevy_platform_support::collections::HashSet;
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use core::{any::TypeId, future::Future, panic::AssertUnwindSafe, task::Poll};
use crossbeam_channel::{Receiver, Sender};
use either::Either;
Expand Down Expand Up @@ -426,7 +426,7 @@ impl AssetServer {

let owned_handle = handle.clone();
let server = self.clone();
let task = IoTaskPool::get().spawn(async move {
let task = ComputeTaskPool::get().spawn(async move {
if let Err(err) = server
.load_internal(Some(owned_handle), path, false, None)
.await
Expand Down Expand Up @@ -487,7 +487,7 @@ impl AssetServer {
let id = handle.id().untyped();

let server = self.clone();
let task = IoTaskPool::get().spawn(async move {
let task = ComputeTaskPool::get().spawn(async move {
let path_clone = path.clone();
match server.load_untyped_async(path).await {
Ok(handle) => server.send_asset_event(InternalAssetEvent::Loaded {
Expand Down Expand Up @@ -716,7 +716,7 @@ impl AssetServer {
pub fn reload<'a>(&self, path: impl Into<AssetPath<'a>>) {
let server = self.clone();
let path = path.into().into_owned();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let mut reloaded = false;

Expand Down Expand Up @@ -810,7 +810,7 @@ impl AssetServer {

let event_sender = self.data.asset_event_sender.clone();

let task = IoTaskPool::get().spawn(async move {
let task = ComputeTaskPool::get().spawn(async move {
match future.await {
Ok(asset) => {
let loaded_asset = LoadedAsset::new_with_dependencies(asset).into();
Expand Down Expand Up @@ -913,7 +913,7 @@ impl AssetServer {

let path = path.into_owned();
let server = self.clone();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let Ok(source) = server.get_source(path.source()) else {
error!(
Expand Down
Loading