Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support-dependent timestamping in wgpu-runner compute example #38

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 106 additions & 47 deletions examples/runners/wgpu/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,22 @@ async fn start_internal(options: &Options, compiled_shader_modules: CompiledShad
.await
.expect("Failed to find an appropriate adapter");

let mut required_features =
wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::TIMESTAMP_QUERY_INSIDE_PASSES;
// Timestamping may not be supported
let timestamping = adapter.features().contains(wgpu::Features::TIMESTAMP_QUERY)
&& adapter
.features()
.contains(wgpu::Features::TIMESTAMP_QUERY_INSIDE_PASSES);

let mut required_features = if timestamping {
wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::TIMESTAMP_QUERY_INSIDE_PASSES
} else {
wgpu::Features::empty()
};
if !timestamping {
eprintln!(
"Adapter reports that timestamping is not supported - no timing information will be available"
);
}
if options.force_spirv_passthru {
required_features |= wgpu::Features::SPIRV_SHADER_PASSTHROUGH;
}
Expand All @@ -43,8 +57,11 @@ async fn start_internal(options: &Options, compiled_shader_modules: CompiledShad
drop(instance);
drop(adapter);

let timestamp_period = queue.get_timestamp_period();

let timestamp_period: Option<f32> = if timestamping {
Some(queue.get_timestamp_period())
} else {
None
};
let entry_point = "main_cs";

// FIXME(eddyb) automate this decision by default.
Expand Down Expand Up @@ -112,20 +129,26 @@ async fn start_internal(options: &Options, compiled_shader_modules: CompiledShad
| wgpu::BufferUsages::COPY_SRC,
});

let timestamp_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Timestamps buffer"),
size: 16,
usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let (timestamp_buffer, timestamp_readback_buffer) = if timestamping {
let timestamp_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Timestamps buffer"),
size: 16,
usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let timestamp_readback_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 16,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: true,
});
timestamp_readback_buffer.unmap();
let timestamp_readback_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 16,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: true,
});
timestamp_readback_buffer.unmap();

(Some(timestamp_buffer), Some(timestamp_readback_buffer))
} else {
(None, None)
};

let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
Expand All @@ -136,11 +159,15 @@ async fn start_internal(options: &Options, compiled_shader_modules: CompiledShad
}],
});

let queries = device.create_query_set(&wgpu::QuerySetDescriptor {
label: None,
count: 2,
ty: wgpu::QueryType::Timestamp,
});
let queries = if timestamping {
Some(device.create_query_set(&wgpu::QuerySetDescriptor {
label: None,
count: 2,
ty: wgpu::QueryType::Timestamp,
}))
} else {
None
};

let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
Expand All @@ -149,9 +176,17 @@ async fn start_internal(options: &Options, compiled_shader_modules: CompiledShad
let mut cpass = encoder.begin_compute_pass(&Default::default());
cpass.set_bind_group(0, &bind_group, &[]);
cpass.set_pipeline(&compute_pipeline);
cpass.write_timestamp(&queries, 0);
if timestamping {
if let Some(queries) = queries.as_ref() {
cpass.write_timestamp(queries, 0);
}
}
cpass.dispatch_workgroups(src_range.len() as u32 / 64, 1, 1);
cpass.write_timestamp(&queries, 1);
if timestamping {
if let Some(queries) = queries.as_ref() {
cpass.write_timestamp(queries, 1);
}
}
}

encoder.copy_buffer_to_buffer(
Expand All @@ -161,38 +196,68 @@ async fn start_internal(options: &Options, compiled_shader_modules: CompiledShad
0,
src.len() as wgpu::BufferAddress,
);
encoder.resolve_query_set(&queries, 0..2, &timestamp_buffer, 0);
encoder.copy_buffer_to_buffer(
&timestamp_buffer,
0,
&timestamp_readback_buffer,
0,
timestamp_buffer.size(),
);

if timestamping {
if let (Some(queries), Some(timestamp_buffer), Some(timestamp_readback_buffer)) = (
queries.as_ref(),
timestamp_buffer.as_ref(),
timestamp_readback_buffer.as_ref(),
) {
encoder.resolve_query_set(queries, 0..2, timestamp_buffer, 0);
encoder.copy_buffer_to_buffer(
timestamp_buffer,
0,
timestamp_readback_buffer,
0,
timestamp_buffer.size(),
);
}
}

queue.submit(Some(encoder.finish()));
let buffer_slice = readback_buffer.slice(..);
let timestamp_slice = timestamp_readback_buffer.slice(..);
timestamp_slice.map_async(wgpu::MapMode::Read, |r| r.unwrap());
if timestamping {
if let Some(timestamp_readback_buffer) = timestamp_readback_buffer.as_ref() {
let timestamp_slice = timestamp_readback_buffer.slice(..);
timestamp_slice.map_async(wgpu::MapMode::Read, |r| r.unwrap());
}
}
buffer_slice.map_async(wgpu::MapMode::Read, |r| r.unwrap());
// NOTE(eddyb) `poll` should return only after the above callbacks fire
// (see also https://github.com/gfx-rs/wgpu/pull/2698 for more details).
device.poll(wgpu::Maintain::Wait);

if timestamping {
if let (Some(timestamp_readback_buffer), Some(timestamp_period)) =
(timestamp_readback_buffer.as_ref(), timestamp_period)
{
{
let timing_data = timestamp_readback_buffer.slice(..).get_mapped_range();
let timings = timing_data
.chunks_exact(8)
.map(|b| u64::from_ne_bytes(b.try_into().unwrap()))
.collect::<Vec<_>>();

println!(
"Took: {:?}",
Duration::from_nanos(
((timings[1] - timings[0]) as f64 * f64::from(timestamp_period)) as u64
)
);
drop(timing_data);
timestamp_readback_buffer.unmap();
}
}
}

let data = buffer_slice.get_mapped_range();
let timing_data = timestamp_slice.get_mapped_range();
let result = data
.chunks_exact(4)
.map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
.collect::<Vec<_>>();
let timings = timing_data
.chunks_exact(8)
.map(|b| u64::from_ne_bytes(b.try_into().unwrap()))
.collect::<Vec<_>>();
drop(data);
readback_buffer.unmap();
drop(timing_data);
timestamp_readback_buffer.unmap();

let mut max = 0;
for (src, out) in src_range.zip(result.iter().copied()) {
if out == u32::MAX {
Expand All @@ -204,10 +269,4 @@ async fn start_internal(options: &Options, compiled_shader_modules: CompiledShad
println!("{src}: {out}");
}
}
println!(
"Took: {:?}",
Duration::from_nanos(
((timings[1] - timings[0]) as f64 * f64::from(timestamp_period)) as u64
)
);
}