Skip to content

Commit

Permalink
Added one_shot feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Kjolnyr committed Apr 16, 2023
1 parent 35537e2 commit 9e1f9de
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 28 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ futures-lite = "1.13.0"
name = "simple"

[[example]]
name = "multi_pass"
name = "multi_pass"

[[example]]
name = "one_shot"
45 changes: 45 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ fn my_system(
}
```

(see [simple.rs](https://github.com/kjolnyr/bevy_app_compute/tree/dev/examples/simple.rs))

### Multiple passes

You can have multiple passes without having to copy data back to the CPU in between:
Expand All @@ -129,12 +131,55 @@ let worker = AppComputeWorkerBuilder::new(world)

```

(see [multi_pass.rs](https://github.com/kjolnyr/bevy_app_compute/tree/dev/examples/multi_pass.rs))

### One shot computes

You can configure your worker to execute only when requested:

```rust
let worker = AppComputeWorkerBuilder::new(world)
.add_uniform("uni", &5.)
.add_staging("values", &[1., 2., 3., 4.])
.add_pass::<SimpleShader>([4, 1, 1], &["uni", "values"])

// This `one_shot()` function will configure your worker accordingly
.one_shot()
.build();

```

Then, you can call `execute()` on your worker when you are ready to execute it:

```rust
// Execute it only when the left mouse button is pressed.
fn on_click_compute(
buttons: Res<Input<MouseButton>>,
mut compute_worker: ResMut<AppComputeWorker<SimpleComputeWorker>>
) {
if !buttons.just_pressed(MouseButton::Left) { return; }

compute_worker.execute();
}
```

It will run at the end of the current frame, and you'll be able to read the data in the next frame.

(see [one_shot.rs](https://github.com/kjolnyr/bevy_app_compute/tree/dev/examples/one_shot.rs))


## Examples

See [examples](https://github.com/kjolnyr/bevy_app_compute/tree/main/examples)


## Features being worked upon

- Ability to read/write between compute passes.
- add more options to the api, like deciding `BufferUsages` or size of buffers.
- Optimization. Right now the code is a complete mess.
- Tests. This badly needs tests.

## Bevy version mapping

|Bevy|bevy_app_compute|
Expand Down
2 changes: 1 addition & 1 deletion examples/multi_pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn main() {
}

fn test(compute_worker: Res<AppComputeWorker<SimpleComputeWorker>>) {
if !compute_worker.available() {
if !compute_worker.ready() {
return;
};

Expand Down
60 changes: 60 additions & 0 deletions examples/one_shot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use bevy::{core::cast_slice, prelude::*, reflect::TypeUuid, render::render_resource::ShaderRef};
use bevy_app_compute::prelude::*;

#[derive(TypeUuid)]
#[uuid = "2545ae14-a9bc-4f03-9ea4-4eb43d1075a7"]
struct SimpleShader;

impl ComputeShader for SimpleShader {
fn shader() -> ShaderRef {
"shaders/simple.wgsl".into()
}
}

#[derive(Resource)]
struct SimpleComputeWorker;

impl ComputeWorker for SimpleComputeWorker {
fn build(world: &mut World) -> AppComputeWorker<Self> {
let worker = AppComputeWorkerBuilder::new(world)
.add_uniform("uni", &5.)
.add_staging("values", &[1., 2., 3., 4.])
.add_pass::<SimpleShader>([4, 1, 1], &["uni", "values"])
.one_shot()
.build();

worker
}
}

fn main() {
App::new()
.add_plugins(DefaultPlugins)
.add_plugin(AppComputePlugin)
.add_plugin(AppComputeWorkerPlugin::<SimpleComputeWorker>::default())
.add_system(on_click_compute)
.add_system(read_data)
.run();
}

fn on_click_compute(
buttons: Res<Input<MouseButton>>,
mut compute_worker: ResMut<AppComputeWorker<SimpleComputeWorker>>
) {
if !buttons.just_pressed(MouseButton::Left) { return; }

compute_worker.execute();
}

fn read_data(mut compute_worker: ResMut<AppComputeWorker<SimpleComputeWorker>>) {
if !compute_worker.ready() {
return;
};

let values = compute_worker.read("values");
let result: &[f32] = cast_slice(&values);

compute_worker.write("values", &result);

println!("got {:?}", result)
}
4 changes: 2 additions & 2 deletions examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ fn main() {
}

fn test(mut compute_worker: ResMut<AppComputeWorker<SimpleComputeWorker>>) {
if !compute_worker.available() {
if !compute_worker.ready() {
return;
};

let values = compute_worker.read("values");
let result: &[f32] = cast_slice(&values);

compute_worker.write("values", [2., 3., 4., 5.]);
compute_worker.write("values", &[2., 3., 4., 5.]);

println!("got {:?}", result)
}
10 changes: 4 additions & 6 deletions src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ impl Plugin for AppComputePlugin {
}
}


/// Plugin to initialise your [`AppComputeWorker<W>`] structs.
pub struct AppComputeWorkerPlugin<W: ComputeWorker> {
_phantom: PhantomData<W>,
Expand All @@ -40,11 +39,10 @@ impl<W: ComputeWorker> Plugin for AppComputeWorkerPlugin<W> {

app.insert_resource(worker)
.add_system(AppComputeWorker::<W>::extract_pipelines)
.add_system(AppComputeWorker::<W>::run.in_base_set(CoreSet::PostUpdate))
.add_system(
AppComputeWorker::<W>::unmap_all
.in_base_set(CoreSet::PostUpdate)
.before(AppComputeWorker::<W>::run),
.add_systems(
(AppComputeWorker::<W>::unmap_all, AppComputeWorker::<W>::run)
.chain()
.in_base_set(CoreSet::PostUpdate),
);
}
}
65 changes: 48 additions & 17 deletions src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use bevy::{
core::{cast_slice, Pod},
prelude::{Res, ResMut, Resource},
render::{
render_resource::{encase::private::WriteInto, Buffer, ComputePipeline, ShaderType},
render_resource::{
encase::{private::WriteInto, StorageBuffer},
Buffer, ComputePipeline, ShaderType,
},
renderer::{RenderDevice, RenderQueue},
},
utils::{HashMap, Uuid},
Expand All @@ -21,11 +24,18 @@ use crate::{
worker_builder::AppComputeWorkerBuilder,
};

#[derive(PartialEq, Clone, Copy)]
pub enum RunMode {
Continuous,
OneShot(bool),
}

#[derive(PartialEq)]
pub enum WorkerState {
Created,
Available,
Working,
FinishedWorking,
}

#[derive(Clone)]
Expand Down Expand Up @@ -62,6 +72,9 @@ impl StaggingBuffers {

/// Struct to manage data transfers from/to the GPU
/// it also handles the logic of your compute work.
/// By default, the run mode of the workers is set to [`RunMode::Continuous`]
/// meaning it will run every frames. If you want to run it deterministically
/// Set the run mode to [`RunMode::OneShot`]
#[derive(Resource)]
pub struct AppComputeWorker<W: ComputeWorker> {
pub(crate) state: WorkerState,
Expand All @@ -77,6 +90,7 @@ pub struct AppComputeWorker<W: ComputeWorker> {
write_requested: bool,
write_buffers_mapped: bool,
read_buffers_mapped: bool,
run_mode: RunMode,
_phantom: PhantomData<W>,
}

Expand Down Expand Up @@ -109,6 +123,7 @@ impl<W: ComputeWorker> From<&AppComputeWorkerBuilder<'_, W>> for AppComputeWorke
write_requested: false,
write_buffers_mapped: false,
read_buffers_mapped: false,
run_mode: builder.run_mode,
_phantom: PhantomData::default(),
}
}
Expand Down Expand Up @@ -241,17 +256,17 @@ impl<W: ComputeWorker> AppComputeWorker<W> {
}

/// Write data to `target` staging buffer.
pub fn write<T: ShaderType + WriteInto + Pod>(&mut self, target: &str, data: T) {
pub fn write<T: ShaderType + WriteInto>(&mut self, target: &str, data: &T) {
let staging_buffer = &self
.staging_buffers
.get(target)
.unwrap_or_else(|| panic!("Unable to find buffer {target} to write into"));

let binding = [data];
let bytes: &[u8] = cast_slice(&binding);
let mut buffer = StorageBuffer::new(Vec::new());
buffer.write::<T>(data).unwrap();

self.render_queue
.write_buffer(&staging_buffer.write, 0, &bytes);
.write_buffer(&staging_buffer.write, 0, &buffer.as_ref());
self.write_requested = true;
}

Expand All @@ -264,28 +279,39 @@ impl<W: ComputeWorker> AppComputeWorker<W> {
}

fn poll(&self) -> bool {
let index = &self
let Some(index) = &self
.submission_index
.clone()
.unwrap_or_else(|| panic!("Cound't find a submission index!"));
else { return false; };

self.render_device
.wgpu_device()
.poll(wgpu::MaintainBase::WaitForSubmissionIndex(index.clone()))
}

/// Check if the worker if available for read/write.
pub fn available(&self) -> bool {
self.state == WorkerState::Available
/// Check if the worker is ready to be read from.
pub fn ready(&self) -> bool {
self.state == WorkerState::FinishedWorking
}

/// Tell the worker to execute the compute shader at the end of the current frame
pub fn execute(&mut self) {
match self.run_mode {
RunMode::Continuous => {}
RunMode::OneShot(_) => self.run_mode = RunMode::OneShot(true),
}
}


fn created(&self) -> bool {
self.state == WorkerState::Created
fn ready_to_execute(&self) -> bool {
(self.state != WorkerState::Working) && (self.run_mode != RunMode::OneShot(false))
}

pub(crate) fn run(mut worker: ResMut<Self>) {
if worker.available() || worker.created() {
if worker.ready() {
worker.state = WorkerState::Available;
}

if worker.ready_to_execute() {
if worker.write_requested {
worker.write_staging_buffers();
worker.write_requested = false;
Expand All @@ -300,18 +326,23 @@ impl<W: ComputeWorker> AppComputeWorker<W> {
worker.map_staging_buffers();
}

if worker.poll() {
worker.state = WorkerState::Available;
if worker.run_mode != RunMode::OneShot(false) && worker.poll() {
worker.state = WorkerState::FinishedWorking;
worker.command_encoder = Some(
worker
.render_device
.create_command_encoder(&CommandEncoderDescriptor { label: None }),
);

match worker.run_mode {
RunMode::Continuous => {}
RunMode::OneShot(_) => worker.run_mode = RunMode::OneShot(false),
};
}
}

pub(crate) fn unmap_all(mut worker: ResMut<Self>) {
if !worker.available() || worker.created() {
if !worker.ready_to_execute() {
return;
};

Expand Down
17 changes: 16 additions & 1 deletion src/worker_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use wgpu::{util::BufferInitDescriptor, BufferDescriptor, BufferUsages};
use crate::{
pipeline_cache::{AppPipelineCache, CachedAppComputePipelineId},
traits::{ComputeShader, ComputeWorker},
worker::{AppComputeWorker, ComputePass, StaggingBuffers},
worker::{AppComputeWorker, ComputePass, StaggingBuffers, RunMode},
};

/// A builder struct to build [`AppComputeWorker<W>`]
Expand All @@ -27,6 +27,7 @@ pub struct AppComputeWorkerBuilder<'a, W: ComputeWorker> {
pub(crate) buffers: HashMap<String, Buffer>,
pub(crate) staging_buffers: HashMap<String, StaggingBuffers>,
pub(crate) passes: Vec<ComputePass>,
pub(crate) run_mode: RunMode,
_phantom: PhantomData<W>,
}

Expand All @@ -41,6 +42,7 @@ impl<'a, W: ComputeWorker> AppComputeWorkerBuilder<'a, W> {
buffers: HashMap::default(),
staging_buffers: HashMap::default(),
passes: vec![],
run_mode: RunMode::Continuous,
_phantom: PhantomData::default(),
}
}
Expand Down Expand Up @@ -226,6 +228,19 @@ impl<'a, W: ComputeWorker> AppComputeWorkerBuilder<'a, W> {
self
}

/// The worker will run every frames.
/// This is the default mode.
pub fn continuous(&mut self) -> &mut Self {
self.run_mode = RunMode::Continuous;
self
}

/// The worker will run when requested.
pub fn one_shot(&mut self) -> &mut Self {
self.run_mode = RunMode::OneShot(false);
self
}

/// Build an [`AppComputeWorker<W>`] from this builder.
pub fn build(&self) -> AppComputeWorker<W> {
AppComputeWorker::from(self)
Expand Down

0 comments on commit 9e1f9de

Please sign in to comment.