From 255018d613e60360b7b88ff70eea27ea7b634195 Mon Sep 17 00:00:00 2001 From: Sandy Carter Date: Sat, 25 Apr 2020 21:59:11 +0200 Subject: [PATCH] ray_tracing_pipeline: Add hit group --- examples/src/bin/nv-ray-tracing.rs | 37 +- vulkano/src/instance/instance.rs | 72 +- .../pipeline/ray_tracing_pipeline/builder.rs | 692 +++++++++++++++--- .../src/pipeline/ray_tracing_pipeline/mod.rs | 38 +- 4 files changed, 722 insertions(+), 117 deletions(-) diff --git a/examples/src/bin/nv-ray-tracing.rs b/examples/src/bin/nv-ray-tracing.rs index 74df758b51..b63e25c4e6 100644 --- a/examples/src/bin/nv-ray-tracing.rs +++ b/examples/src/bin/nv-ray-tracing.rs @@ -193,6 +193,38 @@ void main() { } let ms = ms::Shader::load(device.clone()).unwrap(); + mod cs { + vulkano_shaders::shader! { + ty: "closest_hit", + src: "#version 460 core +#extension GL_NV_ray_tracing : enable + +layout(location = 0) rayPayloadInNV vec4 payload; + +void main() { + payload = vec4(1, 1, 0, 1); +} +" + } + } + let cs = cs::Shader::load(device.clone()).unwrap(); + + mod is { + vulkano_shaders::shader! { + ty: "intersection", + src: "#version 460 core +#extension GL_NV_ray_tracing : enable + +hitAttributeNV bool _unused_but_required; + +void main() { + reportIntersectionNV(0.01, 0); +} +" + } + } + let is = is::Shader::load(device.clone()).unwrap(); + // We set a limit to the recursion of a ray so that the shader does not run infinitely let max_recursion_depth = 5; @@ -202,6 +234,7 @@ void main() { // and to store the result of their path tracing .raygen_shader(rs.main_entry_point(), ()) .miss_shader(ms.main_entry_point(), ()) + .group(RayTracingPipeline::group().closest_hit_shader(cs.main_entry_point(), ()).intersection_shader(is.main_entry_point(), ())) .build(device.clone()) .unwrap(), ); @@ -254,7 +287,9 @@ void main() { ) .unwrap(); let (hit_shader_binding_table, hit_buffer_future) = ImmutableBuffer::from_iter( - (0..0).map(|_| 5u8), + group_handles[2 * group_handle_size..3 * group_handle_size] + .iter() + .copied(), BufferUsage::ray_tracing(), queue.clone(), ) diff --git a/vulkano/src/instance/instance.rs b/vulkano/src/instance/instance.rs index 92fa5292a2..e8f1c509c4 100644 --- a/vulkano/src/instance/instance.rs +++ b/vulkano/src/instance/instance.rs @@ -395,17 +395,45 @@ impl Instance { }; vk.GetPhysicalDeviceProperties2KHR(device, &mut output); - (output.properties, PhysicalDeviceRayTracingProperties { - shaderGroupHandleSize: max(nv_rt_output.shaderGroupHandleSize, khr_rt_output.shaderGroupHandleSize), - maxRecursionDepth: max(nv_rt_output.maxRecursionDepth, khr_rt_output.maxRecursionDepth), - maxShaderGroupStride: max(nv_rt_output.maxShaderGroupStride, khr_rt_output.maxShaderGroupStride), - shaderGroupBaseAlignment: max(nv_rt_output.shaderGroupBaseAlignment, khr_rt_output.shaderGroupBaseAlignment), - maxGeometryCount: max(nv_rt_output.maxGeometryCount, khr_rt_output.maxGeometryCount), - maxInstanceCount: max(nv_rt_output.maxInstanceCount, khr_rt_output.maxInstanceCount), - maxPrimitiveCount: max(nv_rt_output.maxTriangleCount, khr_rt_output.maxPrimitiveCount), - maxDescriptorSetAccelerationStructures: max(nv_rt_output.maxDescriptorSetAccelerationStructures, khr_rt_output.maxDescriptorSetAccelerationStructures), - shaderGroupHandleCaptureReplaySize: khr_rt_output.shaderGroupHandleCaptureReplaySize, - }) + ( + output.properties, + PhysicalDeviceRayTracingProperties { + shader_group_handle_size: max( + nv_rt_output.shaderGroupHandleSize, + khr_rt_output.shaderGroupHandleSize, + ), + max_recursion_depth: max( + nv_rt_output.maxRecursionDepth, + khr_rt_output.maxRecursionDepth, + ), + max_shader_group_stride: max( + nv_rt_output.maxShaderGroupStride, + khr_rt_output.maxShaderGroupStride, + ), + shader_group_base_alignment: max( + nv_rt_output.shaderGroupBaseAlignment, + khr_rt_output.shaderGroupBaseAlignment, + ), + max_geometry_count: max( + nv_rt_output.maxGeometryCount, + khr_rt_output.maxGeometryCount, + ), + max_instance_count: max( + nv_rt_output.maxInstanceCount, + khr_rt_output.maxInstanceCount, + ), + max_primitive_count: max( + nv_rt_output.maxTriangleCount, + khr_rt_output.maxPrimitiveCount, + ), + max_descriptor_set_acceleration_structures: max( + nv_rt_output.maxDescriptorSetAccelerationStructures, + khr_rt_output.maxDescriptorSetAccelerationStructures, + ), + shader_group_handle_capture_replay_size: khr_rt_output + .shaderGroupHandleCaptureReplaySize, + }, + ) }; let queue_families = unsafe { @@ -727,15 +755,15 @@ impl From for InstanceCreationError { #[derive(Default)] struct PhysicalDeviceRayTracingProperties { - shaderGroupHandleSize: u32, - maxRecursionDepth: u32, - maxShaderGroupStride: u32, - shaderGroupBaseAlignment: u32, - maxGeometryCount: u64, - maxInstanceCount: u64, - maxPrimitiveCount: u64, - maxDescriptorSetAccelerationStructures: u32, - shaderGroupHandleCaptureReplaySize: u32, + shader_group_handle_size: u32, + max_recursion_depth: u32, + max_shader_group_stride: u32, + shader_group_base_alignment: u32, + max_geometry_count: u64, + max_instance_count: u64, + max_primitive_count: u64, + max_descriptor_set_acceleration_structures: u32, + shader_group_handle_capture_replay_size: u32, } struct PhysicalDeviceInfos { @@ -1016,13 +1044,13 @@ impl<'a> PhysicalDevice<'a> { /// Returns the size of a shader group handle #[inline] pub fn shader_group_handle_size(&self) -> u32 { - self.infos().properties_ray_tracing.shaderGroupHandleSize + self.infos().properties_ray_tracing.shader_group_handle_size } /// Returns the maximum ray tracing recursion depth for a trace call #[inline] pub fn max_recursion_depth(&self) -> u32 { - self.infos().properties_ray_tracing.maxRecursionDepth + self.infos().properties_ray_tracing.max_recursion_depth } // Internal function to make it easier to get the infos of this device. diff --git a/vulkano/src/pipeline/ray_tracing_pipeline/builder.rs b/vulkano/src/pipeline/ray_tracing_pipeline/builder.rs index 5b5b1b77da..ba9f9df694 100644 --- a/vulkano/src/pipeline/ray_tracing_pipeline/builder.rs +++ b/vulkano/src/pipeline/ray_tracing_pipeline/builder.rs @@ -29,23 +29,140 @@ use check_errors; use vk; use VulkanObject; +pub struct RayTracingPipelineGroupBuilder { + pub closest_hit_shader: Option<(Cs, Css)>, + pub any_hit_shader: Option<(As, Ass)>, + pub intersection_shader: Option<(Is, Iss)>, +} + +impl + RayTracingPipelineGroupBuilder< + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + > +{ + pub(super) fn new() -> Self { + RayTracingPipelineGroupBuilder { + closest_hit_shader: None, + any_hit_shader: None, + intersection_shader: None, + } + } +} + +impl + RayTracingPipelineGroupBuilder +where + Cs1: RayTracingEntryPointAbstract, + Css1: SpecializationConstants, + Cs1::PipelineLayout: Clone + 'static + Send + Sync, // TODO: shouldn't be required + As1: RayTracingEntryPointAbstract, + Ass1: SpecializationConstants, + As1::PipelineLayout: Clone + 'static + Send + Sync, // TODO: shouldn't be required + Is1: RayTracingEntryPointAbstract, + Iss1: SpecializationConstants, + Is1::PipelineLayout: Clone + 'static + Send + Sync, // TODO: shouldn't be required +{ + /// Adds a closest hit shader to group. + // TODO: correct specialization constants + #[inline] + pub fn closest_hit_shader( + self, + shader: Cs2, + specialization_constants: Css2, + ) -> RayTracingPipelineGroupBuilder + where + Cs2: RayTracingEntryPointAbstract, + Css2: SpecializationConstants, + { + RayTracingPipelineGroupBuilder { + closest_hit_shader: Some((shader, specialization_constants)), + any_hit_shader: self.any_hit_shader, + intersection_shader: self.intersection_shader, + } + } + + /// Adds a any hit shader to group. + // TODO: correct specialization constants + #[inline] + pub fn any_hit_shader( + self, + shader: As2, + specialization_constants: Ass2, + ) -> RayTracingPipelineGroupBuilder + where + As2: RayTracingEntryPointAbstract, + Ass2: SpecializationConstants, + { + RayTracingPipelineGroupBuilder { + closest_hit_shader: self.closest_hit_shader, + any_hit_shader: Some((shader, specialization_constants)), + intersection_shader: self.intersection_shader, + } + } + + /// Adds an intersection shader to group. + // TODO: correct specialization constants + #[inline] + pub fn intersection_shader( + self, + shader: Is2, + specialization_constants: Iss2, + ) -> RayTracingPipelineGroupBuilder + where + Is2: RayTracingEntryPointAbstract, + Iss2: SpecializationConstants, + { + RayTracingPipelineGroupBuilder { + closest_hit_shader: self.closest_hit_shader, + any_hit_shader: self.any_hit_shader, + intersection_shader: Some((shader, specialization_constants)), + } + } +} + /// Prototype for a `RayTracingPipeline`. // TODO: we can optimize this by filling directly the raw vk structs -pub struct RayTracingPipelineBuilder { +pub struct RayTracingPipelineBuilder { nv_extension: bool, + // TODO: Should be a list raygen_shader: Option<(Rs, Rss)>, // TODO: Should be a list miss_shader: Option<(Ms, Mss)>, + // TODO: Should be a list + group: RayTracingPipelineGroupBuilder, max_recursion_depth: u32, } -impl RayTracingPipelineBuilder { +impl + RayTracingPipelineBuilder< + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + > +{ /// Builds a new empty builder using the `nv_ray_tracing` extension. pub(super) fn nv(max_recursion_depth: u32) -> Self { RayTracingPipelineBuilder { nv_extension: true, raygen_shader: None, miss_shader: None, + group: RayTracingPipelineGroupBuilder { + closest_hit_shader: None, + any_hit_shader: None, + intersection_shader: None, + }, max_recursion_depth, } } @@ -56,12 +173,18 @@ impl RayTracingPipelineBuilder RayTracingPipelineBuilder +impl + RayTracingPipelineBuilder where Rs: RayTracingEntryPointAbstract, Rss: SpecializationConstants, @@ -69,6 +192,15 @@ where Ms: RayTracingEntryPointAbstract, Mss: SpecializationConstants, Ms::PipelineLayout: Clone + 'static + Send + Sync, // TODO: shouldn't be required + Cs: RayTracingEntryPointAbstract, + Css: SpecializationConstants, + Cs::PipelineLayout: Clone + 'static + Send + Sync, // TODO: shouldn't be required + As: RayTracingEntryPointAbstract, + Ass: SpecializationConstants, + As::PipelineLayout: Clone + 'static + Send + Sync, // TODO: shouldn't be required + Is: RayTracingEntryPointAbstract, + Iss: SpecializationConstants, + Is::PipelineLayout: Clone + 'static + Send + Sync, // TODO: shouldn't be required { /// Builds the ray tracing pipeline, using an inferred a pipeline layout. // TODO: replace Box with a PipelineUnion struct without template params @@ -113,23 +245,191 @@ where let layout = shader.layout().clone(); if let Some((ref shader, _)) = self.miss_shader { let layout = layout.union(shader.layout().clone()); - pipeline_layout = Box::new( - PipelineLayoutDescTweaks::new( - layout, - dynamic_buffers.into_iter().cloned(), - ) - .build(device.clone()) - .unwrap(), - ) as Box<_>; // TODO: error + if let Some((ref shader, _)) = self.group.closest_hit_shader { + let layout = layout.union(shader.layout().clone()); + if let Some((ref shader, _)) = self.group.intersection_shader { + let layout = layout.union(shader.layout().clone()); + if let Some((ref shader, _)) = self.group.any_hit_shader { + let layout = layout.union(shader.layout().clone()); + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } else { + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } + } else { + if let Some((ref shader, _)) = self.group.any_hit_shader { + let layout = layout.union(shader.layout().clone()); + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } else { + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } + } + } else { + if let Some((ref shader, _)) = self.group.intersection_shader { + let layout = layout.union(shader.layout().clone()); + if let Some((ref shader, _)) = self.group.any_hit_shader { + let layout = layout.union(shader.layout().clone()); + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } else { + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } + } else { + if let Some((ref shader, _)) = self.group.any_hit_shader { + let layout = layout.union(shader.layout().clone()); + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } else { + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } + } + } } else { - pipeline_layout = Box::new( - PipelineLayoutDescTweaks::new( - layout, - dynamic_buffers.into_iter().cloned(), - ) - .build(device.clone()) - .unwrap(), - ) as Box<_>; // TODO: error + if let Some((ref shader, _)) = self.group.closest_hit_shader { + let layout = layout.union(shader.layout().clone()); + if let Some((ref shader, _)) = self.group.intersection_shader { + let layout = layout.union(shader.layout().clone()); + if let Some((ref shader, _)) = self.group.any_hit_shader { + let layout = layout.union(shader.layout().clone()); + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } else { + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } + } else { + if let Some((ref shader, _)) = self.group.any_hit_shader { + let layout = layout.union(shader.layout().clone()); + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } else { + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } + } + } else { + if let Some((ref shader, _)) = self.group.intersection_shader { + let layout = layout.union(shader.layout().clone()); + if let Some((ref shader, _)) = self.group.any_hit_shader { + let layout = layout.union(shader.layout().clone()); + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } else { + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } + } else { + if let Some((ref shader, _)) = self.group.any_hit_shader { + let layout = layout.union(shader.layout().clone()); + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } else { + pipeline_layout = Box::new( + PipelineLayoutDescTweaks::new( + layout, + dynamic_buffers.into_iter().cloned(), + ) + .build(device.clone()) + .unwrap(), + ) as Box<_>; // TODO: error + } + } + } } } else { return Err(RayTracingPipelineCreationError::NoRaygenShader); @@ -170,60 +470,129 @@ where let vk = device.pointers(); // Creating the specialization constants of the various stages. - let raygen_shader_specialization = { - let spec_descriptors = Rss::descriptors(); - let constants = &self.raygen_shader.as_ref().unwrap().1; - vk::SpecializationInfo { - mapEntryCount: spec_descriptors.len() as u32, - pMapEntries: spec_descriptors.as_ptr() as *const _, - dataSize: mem::size_of_val(constants), - pData: constants as *const Rss as *const _, + let raygen_stage = { + let raygen_shader_specialization = { + let spec_descriptors = Rss::descriptors(); + let constants = &self.raygen_shader.as_ref().unwrap().1; + vk::SpecializationInfo { + mapEntryCount: spec_descriptors.len() as u32, + pMapEntries: spec_descriptors.as_ptr() as *const _, + dataSize: mem::size_of_val(constants), + pData: constants as *const Rss as *const _, + } + }; + vk::PipelineShaderStageCreateInfo { + sType: vk::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, + pNext: ptr::null(), + flags: 0, // reserved + stage: vk::SHADER_STAGE_RAYGEN_BIT_KHR, + module: self + .raygen_shader + .as_ref() + .unwrap() + .0 + .module() + .internal_object(), + pName: self.raygen_shader.as_ref().unwrap().0.name().as_ptr(), + pSpecializationInfo: &raygen_shader_specialization as *const _, } }; - let raygen_stage = vk::PipelineShaderStageCreateInfo { - sType: vk::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, - pNext: ptr::null(), - flags: 0, // reserved - stage: vk::SHADER_STAGE_RAYGEN_BIT_KHR, - module: self - .raygen_shader - .as_ref() - .unwrap() - .0 - .module() - .internal_object(), - pName: self.raygen_shader.as_ref().unwrap().0.name().as_ptr(), - pSpecializationInfo: &raygen_shader_specialization as *const _, + let miss_stage = match self.miss_shader { + Some((shader, constants)) => { + let specialization = { + let spec_descriptors = Mss::descriptors(); + vk::SpecializationInfo { + mapEntryCount: spec_descriptors.len() as u32, + pMapEntries: spec_descriptors.as_ptr() as *const _, + dataSize: mem::size_of_val(&constants), + pData: &constants as *const Mss as *const _, + } + }; + Some(vk::PipelineShaderStageCreateInfo { + sType: vk::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, + pNext: ptr::null(), + flags: 0, // reserved + stage: vk::SHADER_STAGE_MISS_BIT_KHR, + module: shader.module().internal_object(), + pName: shader.name().as_ptr(), + pSpecializationInfo: &specialization as *const _, + }) + } + None => None, }; - let miss_shader_specialization = { - let spec_descriptors = Mss::descriptors(); - let constants = &self.miss_shader.as_ref().unwrap().1; - vk::SpecializationInfo { - mapEntryCount: spec_descriptors.len() as u32, - pMapEntries: spec_descriptors.as_ptr() as *const _, - dataSize: mem::size_of_val(constants), - pData: constants as *const Mss as *const _, + let closest_hit_stage = match self.group.closest_hit_shader { + Some((shader, constants)) => { + let specialization = { + let spec_descriptors = Css::descriptors(); + vk::SpecializationInfo { + mapEntryCount: spec_descriptors.len() as u32, + pMapEntries: spec_descriptors.as_ptr() as *const _, + dataSize: mem::size_of_val(&constants), + pData: &constants as *const Css as *const _, + } + }; + Some(vk::PipelineShaderStageCreateInfo { + sType: vk::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, + pNext: ptr::null(), + flags: 0, // reserved + stage: vk::SHADER_STAGE_CLOSEST_HIT_BIT_KHR, + module: shader.module().internal_object(), + pName: shader.name().as_ptr(), + pSpecializationInfo: &specialization as *const _, + }) } + None => None, }; - let miss_stage = vk::PipelineShaderStageCreateInfo { - sType: vk::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, - pNext: ptr::null(), - flags: 0, // reserved - stage: vk::SHADER_STAGE_MISS_BIT_NV, - module: self - .miss_shader - .as_ref() - .unwrap() - .0 - .module() - .internal_object(), - pName: self.miss_shader.as_ref().unwrap().0.name().as_ptr(), - pSpecializationInfo: &miss_shader_specialization as *const _, + let any_hit_stage = match self.group.any_hit_shader { + Some((shader, constants)) => { + let specialization = { + let spec_descriptors = Ass::descriptors(); + vk::SpecializationInfo { + mapEntryCount: spec_descriptors.len() as u32, + pMapEntries: spec_descriptors.as_ptr() as *const _, + dataSize: mem::size_of_val(&constants), + pData: &constants as *const Ass as *const _, + } + }; + Some(vk::PipelineShaderStageCreateInfo { + sType: vk::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, + pNext: ptr::null(), + flags: 0, // reserved + stage: vk::SHADER_STAGE_ANY_HIT_BIT_KHR, + module: shader.module().internal_object(), + pName: shader.name().as_ptr(), + pSpecializationInfo: &specialization as *const _, + }) + } + None => None, + }; + let intersection_stage = match self.group.intersection_shader { + Some((shader, constants)) => { + let specialization = { + let spec_descriptors = Iss::descriptors(); + vk::SpecializationInfo { + mapEntryCount: spec_descriptors.len() as u32, + pMapEntries: spec_descriptors.as_ptr() as *const _, + dataSize: mem::size_of_val(&constants), + pData: &constants as *const Iss as *const _, + } + }; + Some(vk::PipelineShaderStageCreateInfo { + sType: vk::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, + pNext: ptr::null(), + flags: 0, // reserved + stage: vk::SHADER_STAGE_INTERSECTION_BIT_KHR, + module: shader.module().internal_object(), + pName: shader.name().as_ptr(), + pSpecializationInfo: &specialization as *const _, + }) + } + None => None, }; let (pipeline, group_count) = if self.nv_extension { - let mut stages = SmallVec::<[_; 1]>::new(); - let mut groups = SmallVec::<[_; 1]>::new(); + let mut stages = SmallVec::<[_; 5]>::new(); + let mut groups = SmallVec::<[_; 5]>::new(); // Raygen groups.push(vk::RayTracingShaderGroupCreateInfoNV { @@ -238,16 +607,66 @@ where stages.push(raygen_stage); // Miss - groups.push(vk::RayTracingShaderGroupCreateInfoNV { - sType: vk::STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_NV, - pNext: ptr::null(), - type_: vk::RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_NV, - generalShader: stages.len() as u32, - closestHitShader: vk::SHADER_UNUSED, - anyHitShader: vk::SHADER_UNUSED, - intersectionShader: vk::SHADER_UNUSED, - }); - stages.push(miss_stage); + match miss_stage { + Some(miss_stage) => { + groups.push(vk::RayTracingShaderGroupCreateInfoNV { + sType: vk::STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_NV, + pNext: ptr::null(), + type_: vk::RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_NV, + generalShader: stages.len() as u32, + closestHitShader: vk::SHADER_UNUSED, + anyHitShader: vk::SHADER_UNUSED, + intersectionShader: vk::SHADER_UNUSED, + }); + stages.push(miss_stage); + } + None => {} + } + + // Groups + if closest_hit_stage.is_some() + || any_hit_stage.is_some() + || intersection_stage.is_some() + { + let mut info = vk::RayTracingShaderGroupCreateInfoNV { + sType: vk::STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_NV, + pNext: ptr::null(), + type_: vk::RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_NV, + generalShader: vk::SHADER_UNUSED, + closestHitShader: vk::SHADER_UNUSED, + anyHitShader: vk::SHADER_UNUSED, + intersectionShader: vk::SHADER_UNUSED, + }; + + match closest_hit_stage { + Some(stage) => { + info.closestHitShader = stages.len() as u32; + stages.push(stage); + } + None => {} + } + + match any_hit_stage { + Some(stage) => { + info.anyHitShader = stages.len() as u32; + stages.push(stage); + } + None => {} + } + + match intersection_stage { + Some(stage) => { + // If there is an intersection stage, then the type is procedural, not triangles + info.type_ = vk::RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_NV; + info.intersectionShader = stages.len() as u32; + stages.push(stage); + } + None => {} + } + + groups.push(info); + } + unsafe { let infos = vk::RayTracingPipelineCreateInfoNV { sType: vk::STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_NV, @@ -275,8 +694,8 @@ where (output.assume_init(), groups.len() as u32) } } else { - let mut stages = SmallVec::<[_; 1]>::new(); - let mut groups = SmallVec::<[_; 1]>::new(); + let mut stages = SmallVec::<[_; 5]>::new(); + let mut groups = SmallVec::<[_; 5]>::new(); // Raygen groups.push(vk::RayTracingShaderGroupCreateInfoKHR { @@ -292,17 +711,67 @@ where stages.push(raygen_stage); // Miss - groups.push(vk::RayTracingShaderGroupCreateInfoKHR { - sType: vk::STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR, - pNext: ptr::null(), - type_: vk::RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR, - generalShader: stages.len() as u32, - closestHitShader: vk::SHADER_UNUSED, - anyHitShader: vk::SHADER_UNUSED, - intersectionShader: vk::SHADER_UNUSED, - pShaderGroupCaptureReplayHandle: ptr::null(), // TODO: - }); - stages.push(miss_stage); + match miss_stage { + Some(miss_stage) => { + groups.push(vk::RayTracingShaderGroupCreateInfoKHR { + sType: vk::STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR, + pNext: ptr::null(), + type_: vk::RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR, + generalShader: stages.len() as u32, + closestHitShader: vk::SHADER_UNUSED, + anyHitShader: vk::SHADER_UNUSED, + intersectionShader: vk::SHADER_UNUSED, + pShaderGroupCaptureReplayHandle: ptr::null(), // TODO: + }); + stages.push(miss_stage); + } + None => {} + } + + // Groups + if closest_hit_stage.is_some() + || any_hit_stage.is_some() + || intersection_stage.is_some() + { + let mut info = vk::RayTracingShaderGroupCreateInfoKHR { + sType: vk::STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR, + pNext: ptr::null(), + type_: vk::RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR, + generalShader: vk::SHADER_UNUSED, + closestHitShader: vk::SHADER_UNUSED, + anyHitShader: vk::SHADER_UNUSED, + intersectionShader: vk::SHADER_UNUSED, + pShaderGroupCaptureReplayHandle: ptr::null(), // TODO: + }; + + match closest_hit_stage { + Some(stage) => { + info.closestHitShader = stages.len() as u32; + stages.push(stage); + } + None => {} + } + + match any_hit_stage { + Some(stage) => { + info.anyHitShader = stages.len() as u32; + stages.push(stage); + } + None => {} + } + + match intersection_stage { + Some(stage) => { + // If there is an intersection stage, then the type is procedural, not triangles + info.type_ = vk::RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR; + info.intersectionShader = stages.len() as u32; + stages.push(stage); + } + None => {} + } + + groups.push(info); + } let library_info = vk::PipelineLibraryCreateInfoKHR { sType: vk::STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR, @@ -363,7 +832,9 @@ where // TODO: add build_with_cache method } -impl RayTracingPipelineBuilder { +impl + RayTracingPipelineBuilder +{ // TODO: add pipeline derivate system /// Adds a raygen shader group to use. @@ -374,7 +845,7 @@ impl RayTracingPipelineBuilder { self, shader: Rs2, specialization_constants: Rss2, - ) -> RayTracingPipelineBuilder + ) -> RayTracingPipelineBuilder where Rs2: RayTracingEntryPointAbstract, Rss2: SpecializationConstants, @@ -384,6 +855,7 @@ impl RayTracingPipelineBuilder { max_recursion_depth: self.max_recursion_depth, raygen_shader: Some((shader, specialization_constants)), miss_shader: self.miss_shader, + group: self.group, } } @@ -393,7 +865,7 @@ impl RayTracingPipelineBuilder { #[inline] pub fn miss_shader( self, shader: Ms2, specialization_constants: Mss2, - ) -> RayTracingPipelineBuilder + ) -> RayTracingPipelineBuilder where Ms2: RayTracingEntryPointAbstract, Mss2: SpecializationConstants, @@ -403,22 +875,58 @@ impl RayTracingPipelineBuilder { max_recursion_depth: self.max_recursion_depth, raygen_shader: self.raygen_shader, miss_shader: Some((shader, specialization_constants)), + group: self.group, + } + } + + /// Add a ray tracing group + #[inline] + pub fn group( + self, + group: RayTracingPipelineGroupBuilder, + ) -> RayTracingPipelineBuilder + where + Cs2: RayTracingEntryPointAbstract, + Css2: SpecializationConstants, + As2: RayTracingEntryPointAbstract, + Ass2: SpecializationConstants, + Is2: RayTracingEntryPointAbstract, + Iss2: SpecializationConstants, + { + RayTracingPipelineBuilder { + nv_extension: self.nv_extension, + max_recursion_depth: self.max_recursion_depth, + raygen_shader: self.raygen_shader, + miss_shader: self.miss_shader, + group, } } } -impl Clone for RayTracingPipelineBuilder +impl Clone + for RayTracingPipelineBuilder where Rs: Clone, Rss: Clone, Ms: Clone, Mss: Clone, + Cs: Clone, + Css: Clone, + As: Clone, + Ass: Clone, + Is: Clone, + Iss: Clone, { fn clone(&self) -> Self { RayTracingPipelineBuilder { nv_extension: self.nv_extension, raygen_shader: self.raygen_shader.clone(), miss_shader: self.miss_shader.clone(), + group: RayTracingPipelineGroupBuilder { + closest_hit_shader: self.group.closest_hit_shader.clone(), + any_hit_shader: self.group.any_hit_shader.clone(), + intersection_shader: self.group.intersection_shader.clone(), + }, max_recursion_depth: self.max_recursion_depth, } } diff --git a/vulkano/src/pipeline/ray_tracing_pipeline/mod.rs b/vulkano/src/pipeline/ray_tracing_pipeline/mod.rs index 410328aaa0..49cdd303ca 100644 --- a/vulkano/src/pipeline/ray_tracing_pipeline/mod.rs +++ b/vulkano/src/pipeline/ray_tracing_pipeline/mod.rs @@ -27,6 +27,7 @@ use VulkanObject; pub use self::builder::RayTracingPipelineBuilder; pub use self::creation_error::RayTracingPipelineCreationError; +use pipeline::ray_tracing_pipeline::builder::RayTracingPipelineGroupBuilder; mod builder; mod creation_error; @@ -57,7 +58,18 @@ impl RayTracingPipeline<()> { /// Returns a builder object that you can fill with the various parameters. pub fn nv<'a>( max_recursion_depth: u32, - ) -> RayTracingPipelineBuilder { + ) -> RayTracingPipelineBuilder< + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + > { RayTracingPipelineBuilder::nv(max_recursion_depth) } @@ -66,9 +78,31 @@ impl RayTracingPipeline<()> { /// Returns a builder object that you can fill with the various parameters. pub fn khr<'a>( max_recursion_depth: u32, - ) -> RayTracingPipelineBuilder { + ) -> RayTracingPipelineBuilder< + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + > { RayTracingPipelineBuilder::khr(max_recursion_depth) } + + pub fn group<'a>() -> RayTracingPipelineGroupBuilder< + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + EmptyEntryPointDummy, + (), + > { + RayTracingPipelineGroupBuilder::new() + } } impl RayTracingPipeline {