Skip to content

Commit

Permalink
ray_tracing_pipeline: Add hit group
Browse files Browse the repository at this point in the history
  • Loading branch information
bwrsandman committed Apr 28, 2020
1 parent b6f7878 commit 255018d
Show file tree
Hide file tree
Showing 4 changed files with 722 additions and 117 deletions.
37 changes: 36 additions & 1 deletion examples/src/bin/nv-ray-tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,38 @@ void main() {
}
let ms = ms::Shader::load(device.clone()).unwrap();

mod cs {
vulkano_shaders::shader! {
ty: "closest_hit",
src: "#version 460 core
#extension GL_NV_ray_tracing : enable
layout(location = 0) rayPayloadInNV vec4 payload;
void main() {
payload = vec4(1, 1, 0, 1);
}
"
}
}
let cs = cs::Shader::load(device.clone()).unwrap();

mod is {
vulkano_shaders::shader! {
ty: "intersection",
src: "#version 460 core
#extension GL_NV_ray_tracing : enable
hitAttributeNV bool _unused_but_required;
void main() {
reportIntersectionNV(0.01, 0);
}
"
}
}
let is = is::Shader::load(device.clone()).unwrap();

// We set a limit to the recursion of a ray so that the shader does not run infinitely
let max_recursion_depth = 5;

Expand All @@ -202,6 +234,7 @@ void main() {
// and to store the result of their path tracing
.raygen_shader(rs.main_entry_point(), ())
.miss_shader(ms.main_entry_point(), ())
.group(RayTracingPipeline::group().closest_hit_shader(cs.main_entry_point(), ()).intersection_shader(is.main_entry_point(), ()))
.build(device.clone())
.unwrap(),
);
Expand Down Expand Up @@ -254,7 +287,9 @@ void main() {
)
.unwrap();
let (hit_shader_binding_table, hit_buffer_future) = ImmutableBuffer::from_iter(
(0..0).map(|_| 5u8),
group_handles[2 * group_handle_size..3 * group_handle_size]
.iter()
.copied(),
BufferUsage::ray_tracing(),
queue.clone(),
)
Expand Down
72 changes: 50 additions & 22 deletions vulkano/src/instance/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,17 +395,45 @@ impl Instance {
};

vk.GetPhysicalDeviceProperties2KHR(device, &mut output);
(output.properties, PhysicalDeviceRayTracingProperties {
shaderGroupHandleSize: max(nv_rt_output.shaderGroupHandleSize, khr_rt_output.shaderGroupHandleSize),
maxRecursionDepth: max(nv_rt_output.maxRecursionDepth, khr_rt_output.maxRecursionDepth),
maxShaderGroupStride: max(nv_rt_output.maxShaderGroupStride, khr_rt_output.maxShaderGroupStride),
shaderGroupBaseAlignment: max(nv_rt_output.shaderGroupBaseAlignment, khr_rt_output.shaderGroupBaseAlignment),
maxGeometryCount: max(nv_rt_output.maxGeometryCount, khr_rt_output.maxGeometryCount),
maxInstanceCount: max(nv_rt_output.maxInstanceCount, khr_rt_output.maxInstanceCount),
maxPrimitiveCount: max(nv_rt_output.maxTriangleCount, khr_rt_output.maxPrimitiveCount),
maxDescriptorSetAccelerationStructures: max(nv_rt_output.maxDescriptorSetAccelerationStructures, khr_rt_output.maxDescriptorSetAccelerationStructures),
shaderGroupHandleCaptureReplaySize: khr_rt_output.shaderGroupHandleCaptureReplaySize,
})
(
output.properties,
PhysicalDeviceRayTracingProperties {
shader_group_handle_size: max(
nv_rt_output.shaderGroupHandleSize,
khr_rt_output.shaderGroupHandleSize,
),
max_recursion_depth: max(
nv_rt_output.maxRecursionDepth,
khr_rt_output.maxRecursionDepth,
),
max_shader_group_stride: max(
nv_rt_output.maxShaderGroupStride,
khr_rt_output.maxShaderGroupStride,
),
shader_group_base_alignment: max(
nv_rt_output.shaderGroupBaseAlignment,
khr_rt_output.shaderGroupBaseAlignment,
),
max_geometry_count: max(
nv_rt_output.maxGeometryCount,
khr_rt_output.maxGeometryCount,
),
max_instance_count: max(
nv_rt_output.maxInstanceCount,
khr_rt_output.maxInstanceCount,
),
max_primitive_count: max(
nv_rt_output.maxTriangleCount,
khr_rt_output.maxPrimitiveCount,
),
max_descriptor_set_acceleration_structures: max(
nv_rt_output.maxDescriptorSetAccelerationStructures,
khr_rt_output.maxDescriptorSetAccelerationStructures,
),
shader_group_handle_capture_replay_size: khr_rt_output
.shaderGroupHandleCaptureReplaySize,
},
)
};

let queue_families = unsafe {
Expand Down Expand Up @@ -727,15 +755,15 @@ impl From<Error> for InstanceCreationError {

#[derive(Default)]
struct PhysicalDeviceRayTracingProperties {
shaderGroupHandleSize: u32,
maxRecursionDepth: u32,
maxShaderGroupStride: u32,
shaderGroupBaseAlignment: u32,
maxGeometryCount: u64,
maxInstanceCount: u64,
maxPrimitiveCount: u64,
maxDescriptorSetAccelerationStructures: u32,
shaderGroupHandleCaptureReplaySize: u32,
shader_group_handle_size: u32,
max_recursion_depth: u32,
max_shader_group_stride: u32,
shader_group_base_alignment: u32,
max_geometry_count: u64,
max_instance_count: u64,
max_primitive_count: u64,
max_descriptor_set_acceleration_structures: u32,
shader_group_handle_capture_replay_size: u32,
}

struct PhysicalDeviceInfos {
Expand Down Expand Up @@ -1016,13 +1044,13 @@ impl<'a> PhysicalDevice<'a> {
/// Returns the size of a shader group handle
#[inline]
pub fn shader_group_handle_size(&self) -> u32 {
self.infos().properties_ray_tracing.shaderGroupHandleSize
self.infos().properties_ray_tracing.shader_group_handle_size
}

/// Returns the maximum ray tracing recursion depth for a trace call
#[inline]
pub fn max_recursion_depth(&self) -> u32 {
self.infos().properties_ray_tracing.maxRecursionDepth
self.infos().properties_ray_tracing.max_recursion_depth
}

// Internal function to make it easier to get the infos of this device.
Expand Down
Loading

0 comments on commit 255018d

Please sign in to comment.