diff --git a/Cargo.lock b/Cargo.lock index 1664e7bd7a..c68bb49db4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1736,6 +1736,8 @@ dependencies = [ name = "triangle-util" version = "0.0.0" dependencies = [ + "ash", + "glam", "vulkano", "vulkano-shaders", "vulkano-util", diff --git a/examples/triangle-util/Cargo.toml b/examples/triangle-util/Cargo.toml index f5d7a10104..d4d99abf9b 100644 --- a/examples/triangle-util/Cargo.toml +++ b/examples/triangle-util/Cargo.toml @@ -21,3 +21,5 @@ vulkano-util = { workspace = true } # The Vulkan library doesn't provide any functionality to create and handle windows, as # this would be out of scope. In order to open a window, we are going to use the `winit` crate. winit = { workspace = true, default-features = true } +ash = { workspace = true } +glam = { workspace = true } diff --git a/examples/triangle-util/raytrace.rchit b/examples/triangle-util/raytrace.rchit new file mode 100644 index 0000000000..52c407b96a --- /dev/null +++ b/examples/triangle-util/raytrace.rchit @@ -0,0 +1,10 @@ +#version 460 +#extension GL_EXT_ray_tracing : require + +layout(location = 0) rayPayloadInEXT vec3 hitValue; +hitAttributeEXT vec2 attribs; + +void main() { + vec3 barycentrics = vec3(1.0 - attribs.x - attribs.y, attribs.x, attribs.y); + hitValue = barycentrics; +} diff --git a/examples/triangle-util/raytrace.rgen b/examples/triangle-util/raytrace.rgen new file mode 100644 index 0000000000..8a9416e201 --- /dev/null +++ b/examples/triangle-util/raytrace.rgen @@ -0,0 +1,43 @@ +#version 460 +#extension GL_EXT_ray_tracing : require + +struct Camera { + mat4 viewProj; // Camera view * projection + mat4 viewInverse; // Camera inverse view matrix + mat4 projInverse; // Camera inverse projection matrix +}; + +layout(location = 0) rayPayloadEXT vec3 hitValue; + +layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS; +layout(set = 0, binding = 1) uniform _Camera { Camera camera; }; +layout(set = 1, binding = 0, rgba32f) uniform image2D image; + +void main() { + const vec2 pixelCenter = vec2(gl_LaunchIDEXT.xy) + vec2(0.5); + const vec2 inUV = pixelCenter / vec2(gl_LaunchSizeEXT.xy); + vec2 d = inUV * 2.0 - 1.0; + + vec4 origin = camera.viewInverse * vec4(0, 0, 0, 1); + vec4 target = camera.projInverse * vec4(d.x, d.y, 1, 1); + vec4 direction = camera.viewInverse * vec4(normalize(target.xyz), 0); + + uint rayFlags = gl_RayFlagsOpaqueEXT; + float tMin = 0.001; + float tMax = 10000.0; + + traceRayEXT(topLevelAS, // acceleration structure + rayFlags, // rayFlags + 0xFF, // cullMask + 0, // sbtRecordOffset + 0, // sbtRecordStride + 0, // missIndex + origin.xyz, // ray origin + tMin, // ray min range + direction.xyz, // ray direction + tMax, // ray max range + 0 // payload (location = 0) + ); + + imageStore(image, ivec2(gl_LaunchIDEXT.xy), vec4(hitValue, 1.0)); +} diff --git a/examples/triangle-util/raytrace.rmiss b/examples/triangle-util/raytrace.rmiss new file mode 100644 index 0000000000..1c584d5420 --- /dev/null +++ b/examples/triangle-util/raytrace.rmiss @@ -0,0 +1,6 @@ +#version 460 +#extension GL_EXT_ray_tracing : require + +layout(location = 0) rayPayloadInEXT vec3 hitValue; + +void main() { hitValue = vec3(0.0, 0.0, 0.2); } diff --git a/vulkano/src/buffer/usage.rs b/vulkano/src/buffer/usage.rs index 7b57595561..72f9eb0a19 100644 --- a/vulkano/src/buffer/usage.rs +++ b/vulkano/src/buffer/usage.rs @@ -97,13 +97,13 @@ vulkan_bitflags! { RequiresAllOf([DeviceExtension(khr_acceleration_structure)]), ]), - /* TODO: enable + // TODO: document SHADER_BINDING_TABLE = SHADER_BINDING_TABLE_KHR RequiresOneOf([ RequiresAllOf([DeviceExtension(khr_ray_tracing_pipeline)]), RequiresAllOf([DeviceExtension(nv_ray_tracing)]), - ]),*/ + ]), /* TODO: enable // TODO: document diff --git a/vulkano/src/command_buffer/auto/builder.rs b/vulkano/src/command_buffer/auto/builder.rs index ca953227b9..0332ec2d80 100644 --- a/vulkano/src/command_buffer/auto/builder.rs +++ b/vulkano/src/command_buffer/auto/builder.rs @@ -29,6 +29,7 @@ use crate::{ vertex_input::VertexInputState, viewport::{Scissor, Viewport}, }, + ray_tracing::RayTracingPipeline, ComputePipeline, DynamicState, GraphicsPipeline, PipelineBindPoint, PipelineLayout, }, query::{QueryControlFlags, QueryPool, QueryType}, @@ -1292,6 +1293,7 @@ pub(in crate::command_buffer) struct CommandBufferBuilderState { pub(in crate::command_buffer) index_buffer: Option, pub(in crate::command_buffer) pipeline_compute: Option>, pub(in crate::command_buffer) pipeline_graphics: Option>, + pub(in crate::command_buffer) pipeline_ray_tracing: Option>, pub(in crate::command_buffer) vertex_buffers: HashMap>, pub(in crate::command_buffer) push_constants: RangeSet, pub(in crate::command_buffer) push_constants_pipeline_layout: Option>, diff --git a/vulkano/src/command_buffer/commands/bind_push.rs b/vulkano/src/command_buffer/commands/bind_push.rs index 37d54ac6df..b64d4b8a7a 100644 --- a/vulkano/src/command_buffer/commands/bind_push.rs +++ b/vulkano/src/command_buffer/commands/bind_push.rs @@ -10,8 +10,8 @@ use crate::{ device::{DeviceOwned, QueueFlags}, memory::is_aligned, pipeline::{ - graphics::vertex_input::VertexBuffersCollection, ComputePipeline, GraphicsPipeline, - PipelineBindPoint, PipelineLayout, + graphics::vertex_input::VertexBuffersCollection, ray_tracing::RayTracingPipeline, + ComputePipeline, GraphicsPipeline, PipelineBindPoint, PipelineLayout, }, DeviceSize, Requires, RequiresAllOf, RequiresOneOf, ValidationError, Version, VulkanObject, }; @@ -794,6 +794,9 @@ impl RecordingCommandBuffer { })); } } + PipelineBindPoint::RayTracing => { + // TODO: RayTracing: Validation + } } if first_set + descriptor_sets as u32 > pipeline_layout.set_layouts().len() as u32 { @@ -1018,6 +1021,26 @@ impl RecordingCommandBuffer { self } + pub unsafe fn bind_pipeline_ray_tracing(&mut self, pipeline: &RayTracingPipeline) -> &mut Self { + // TODO: RayTracing: Validation + self.bind_pipeline_ray_tracing_unchecked(pipeline) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn bind_pipeline_ray_tracing_unchecked( + &mut self, + pipeline: &RayTracingPipeline, + ) -> &mut Self { + let fns = self.device().fns(); + (fns.v1_0.cmd_bind_pipeline)( + self.handle(), + ash::vk::PipelineBindPoint::RAY_TRACING_KHR, + pipeline.handle(), + ); + + self + } + #[inline] pub unsafe fn bind_vertex_buffers( &mut self, @@ -1395,6 +1418,9 @@ impl RecordingCommandBuffer { })); } } + PipelineBindPoint::RayTracing => { + // TODO: RayTracing + } } // VUID-vkCmdPushDescriptorSetKHR-commonparent diff --git a/vulkano/src/command_buffer/commands/pipeline.rs b/vulkano/src/command_buffer/commands/pipeline.rs index e85bd48635..17a09cb3de 100644 --- a/vulkano/src/command_buffer/commands/pipeline.rs +++ b/vulkano/src/command_buffer/commands/pipeline.rs @@ -22,6 +22,7 @@ use crate::{ subpass::PipelineSubpassType, vertex_input::{RequiredVertexInputsVUIDs, VertexInputRate}, }, + ray_tracing::ShaderBindingTable, DynamicState, GraphicsPipeline, Pipeline, PipelineLayout, }, query::QueryType, @@ -1592,6 +1593,53 @@ impl AutoCommandBufferBuilder { self } + pub unsafe fn trace_rays( + &mut self, + shader_binding_table: ShaderBindingTable, + width: u32, + height: u32, + depth: u32, + ) -> Result<&mut Self, Box> { + // TODO: RayTrace: Validation + + Ok(self.trace_rays_unchecked(shader_binding_table, width, height, depth)) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn trace_rays_unchecked( + &mut self, + shader_binding_table: ShaderBindingTable, + width: u32, + height: u32, + depth: u32, + ) -> &mut Self { + // TODO: RayTracing: as_deref() + let pipeline = self.builder_state.pipeline_ray_tracing.as_deref().unwrap(); + + let mut used_resources = Vec::new(); + self.add_descriptor_sets_resources(&mut used_resources, pipeline); + self.add_shader_binding_table_buffer_resources( + &mut used_resources, + shader_binding_table.buffer(), + ); + + self.add_command("ray_trace", used_resources, move |out| { + out.trace_rays_unchecked(&shader_binding_table, width, height, depth); + }); + + self + } + + fn validate_trace_rays( + &self, + shader_binding_table: &ShaderBindingTable, + width: u32, + height: u32, + depth: u32, + ) -> Result<(), Box> { + todo!() + } + fn validate_pipeline_descriptor_sets( &self, vuid_type: VUIDType, @@ -3714,6 +3762,21 @@ impl AutoCommandBufferBuilder { }, )); } + + fn add_shader_binding_table_buffer_resources( + &self, + used_resources: &mut Vec<(ResourceUseRef2, Resource)>, + sbt_buffer: &Subbuffer<[u8]>, + ) { + used_resources.push(( + ResourceInCommand::ShaderBindingTableBuffer.into(), + Resource::Buffer { + buffer: sbt_buffer.clone(), + range: 0..sbt_buffer.size(), + memory_access: PipelineStageAccessFlags::RayTracingShader_ShaderBindingTableRead, + }, + )); + } } impl RecordingCommandBuffer { @@ -4947,6 +5010,47 @@ impl RecordingCommandBuffer { self } + + pub unsafe fn trace_rays( + &mut self, + shader_binding_table: &ShaderBindingTable, + width: u32, + height: u32, + depth: u32, + ) -> Result<&mut Self, Box> { + // self.validate_trace_ray()?; + // TODO: RayTracing: Validation + + Ok(self.trace_rays_unchecked(shader_binding_table, width, height, depth)) + } + + fn validate_trace_rays(&self) -> Result<(), Box> { + todo!() + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn trace_rays_unchecked( + &mut self, + shader_binding_table: &ShaderBindingTable, + width: u32, + height: u32, + depth: u32, + ) -> &mut Self { + let fns = self.device().fns(); + + (fns.khr_ray_tracing_pipeline.cmd_trace_rays_khr)( + self.handle(), + shader_binding_table.raygen(), + shader_binding_table.miss(), + shader_binding_table.hit(), + shader_binding_table.callable(), + width, + height, + depth, + ); + + self + } } #[derive(Clone, Copy)] diff --git a/vulkano/src/command_buffer/mod.rs b/vulkano/src/command_buffer/mod.rs index 88e719c8b8..4515904325 100644 --- a/vulkano/src/command_buffer/mod.rs +++ b/vulkano/src/command_buffer/mod.rs @@ -1617,6 +1617,7 @@ pub enum ResourceInCommand { SecondaryCommandBuffer { index: u32 }, Source, VertexBuffer { binding: u32 }, + ShaderBindingTableBuffer, } #[doc(hidden)] diff --git a/vulkano/src/device/mod.rs b/vulkano/src/device/mod.rs index 62f8f40006..caf6a3eee6 100644 --- a/vulkano/src/device/mod.rs +++ b/vulkano/src/device/mod.rs @@ -114,6 +114,7 @@ use crate::{ instance::{Instance, InstanceOwned, InstanceOwnedDebugWrapper}, macros::{impl_id_counter, vulkan_bitflags}, memory::{ExternalMemoryHandleType, MemoryFdProperties, MemoryRequirements}, + pipeline::ray_tracing::RayTracingPipeline, Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, Version, VulkanError, VulkanObject, }; @@ -1304,6 +1305,63 @@ impl Device { Ok(()) } + + pub fn get_ray_tracing_shader_group_handles( + &self, + ray_tracing_pipeline: &RayTracingPipeline, + first_group: u32, + group_count: u32, + ) -> Result> { + if !self.enabled_features().ray_tracing_pipeline + || self + .physical_device() + .properties() + .shader_group_handle_size + .is_none() + { + Err(Box::new(ValidationError { + problem: "device property `shader_group_handle_size` is empty".into(), + requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::DeviceFeature( + "ray_tracing_pipeline", + )])]), + ..Default::default() + }))?; + }; + + if (first_group + group_count) as usize > ray_tracing_pipeline.groups().len() { + Err(Box::new(ValidationError { + problem: "the sum of `first_group` and `group_count` must be less than or equal\ + to the number of shader groups in pipeline" + .into(), + vuids: &["VUID-vkGetRayTracingShaderGroupHandlesKHR-firstGroup-02419"], + ..Default::default() + }))? + } + // TODO: VUID-vkGetRayTracingShaderGroupHandlesKHR-pipeline-07828 + + let handle_size = self + .physical_device() + .properties() + .shader_group_handle_size + .unwrap(); + + let mut data = vec![0u8; (handle_size * group_count) as usize]; + let fns = self.fns(); + unsafe { + (fns.khr_ray_tracing_pipeline + .get_ray_tracing_shader_group_handles_khr)( + self.handle, + ray_tracing_pipeline.handle(), + first_group, + group_count, + data.len(), + data.as_mut_ptr().cast(), + ) + .result() + .map_err(VulkanError::from)?; + } + Ok(ShaderGroupHandlesData { data, handle_size }) + } } impl Debug for Device { @@ -2134,6 +2192,50 @@ impl Deref for DeviceOwnedDebugWrapper { } } +#[derive(Clone, Debug)] +pub struct ShaderGroupHandlesData { + data: Vec, + handle_size: u32, +} + +impl ShaderGroupHandlesData { + pub fn handle_size(&self) -> u32 { + self.handle_size + } +} + +pub struct ShaderGroupHandlesDataIter<'a> { + data: &'a [u8], + handle_size: usize, + index: usize, +} + +impl<'a> Iterator for ShaderGroupHandlesDataIter<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + if self.index >= self.data.len() { + None + } else { + let end = self.index + self.handle_size; + let slice = &self.data[self.index..end]; + self.index = end; + Some(slice) + } + } +} +impl<'a> ExactSizeIterator for ShaderGroupHandlesDataIter<'a> {} + +impl ShaderGroupHandlesData { + pub fn iter(&self) -> ShaderGroupHandlesDataIter<'_> { + ShaderGroupHandlesDataIter { + data: &self.data, + handle_size: self.handle_size as usize, + index: 0, + } + } +} + #[cfg(test)] mod tests { use crate::device::{ diff --git a/vulkano/src/pipeline/compute.rs b/vulkano/src/pipeline/compute.rs index 8766d0f08c..eb5853da7d 100644 --- a/vulkano/src/pipeline/compute.rs +++ b/vulkano/src/pipeline/compute.rs @@ -57,7 +57,7 @@ impl ComputePipeline { cache: Option>, create_info: ComputePipelineCreateInfo, ) -> Result, Validated> { - Self::validate_new(&device, cache.as_ref().map(AsRef::as_ref), &create_info)?; + Self::validate_new(&device, cache.as_deref(), &create_info)?; Ok(unsafe { Self::new_unchecked(device, cache, create_info) }?) } diff --git a/vulkano/src/pipeline/graphics/mod.rs b/vulkano/src/pipeline/graphics/mod.rs index a90633b2f4..db24f1ce24 100644 --- a/vulkano/src/pipeline/graphics/mod.rs +++ b/vulkano/src/pipeline/graphics/mod.rs @@ -178,7 +178,7 @@ impl GraphicsPipeline { cache: Option>, create_info: GraphicsPipelineCreateInfo, ) -> Result, Validated> { - Self::validate_new(&device, cache.as_ref().map(AsRef::as_ref), &create_info)?; + Self::validate_new(&device, cache.as_deref(), &create_info)?; Ok(unsafe { Self::new_unchecked(device, cache, create_info) }?) } diff --git a/vulkano/src/pipeline/mod.rs b/vulkano/src/pipeline/mod.rs index 435a6f70eb..68a56e94ee 100644 --- a/vulkano/src/pipeline/mod.rs +++ b/vulkano/src/pipeline/mod.rs @@ -23,6 +23,7 @@ pub mod cache; pub mod compute; pub mod graphics; pub mod layout; +pub mod ray_tracing; pub(crate) mod shader; /// A trait for operations shared between pipeline types. @@ -60,13 +61,13 @@ vulkan_enum! { // TODO: document Graphics = GRAPHICS, - /* TODO: enable + // TODO: document RayTracing = RAY_TRACING_KHR RequiresOneOf([ RequiresAllOf([DeviceExtension(khr_ray_tracing_pipeline)]), RequiresAllOf([DeviceExtension(nv_ray_tracing)]), - ]),*/ + ]), /* TODO: enable // TODO: document diff --git a/vulkano/src/pipeline/ray_tracing/mod.rs b/vulkano/src/pipeline/ray_tracing/mod.rs new file mode 100644 index 0000000000..aa244dd9ae --- /dev/null +++ b/vulkano/src/pipeline/ray_tracing/mod.rs @@ -0,0 +1,592 @@ +use std::{collections::hash_map::Entry, mem::MaybeUninit, num::NonZeroU64, ptr, sync::Arc}; + +use ahash::{HashMap, HashSet}; +use ash::vk::StridedDeviceAddressRegionKHR; +use smallvec::SmallVec; + +use crate::{ + buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer}, + device::{Device, DeviceOwned, DeviceOwnedDebugWrapper, DeviceOwnedVulkanObject}, + instance::InstanceOwnedDebugWrapper, + macros::impl_id_counter, + memory::{ + allocator::{align_up, AllocationCreateInfo, MemoryAllocator, MemoryTypeFilter}, + DeviceAlignment, + }, + shader::DescriptorBindingRequirements, + Validated, ValidationError, VulkanError, VulkanObject, +}; + +use super::{ + cache::PipelineCache, DynamicState, Pipeline, PipelineBindPoint, PipelineCreateFlags, + PipelineLayout, PipelineShaderStageCreateInfo, PipelineShaderStageCreateInfoExtensionsVk, + PipelineShaderStageCreateInfoFields1Vk, PipelineShaderStageCreateInfoFields2Vk, +}; + +#[derive(Debug)] +pub struct RayTracingPipeline { + handle: ash::vk::Pipeline, + device: InstanceOwnedDebugWrapper>, + id: NonZeroU64, + + flags: PipelineCreateFlags, + layout: DeviceOwnedDebugWrapper>, + + descriptor_binding_requirements: HashMap<(u32, u32), DescriptorBindingRequirements>, + num_used_descriptor_sets: u32, + + groups: SmallVec<[RayTracingShaderGroupCreateInfo; 5]>, + stages: SmallVec<[PipelineShaderStageCreateInfo; 5]>, +} + +impl RayTracingPipeline { + /// Creates a new `RayTracingPipeline`. + #[inline] + pub fn new( + device: Arc, + cache: Option>, + create_info: RayTracingPipelineCreateInfo, + ) -> Result, Validated> { + // Self::validate_new(&device, cache.as_deref(), &create_info)?; + + unsafe { Ok(Self::new_unchecked(device, cache, create_info)?) } + } + + fn validate_new( + device: &Device, + cache: Option<&PipelineCache>, + create_info: &RayTracingPipelineCreateInfo, + ) -> Result<(), Box> { + todo!() + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn new_unchecked( + device: Arc, + cache: Option>, + create_info: RayTracingPipelineCreateInfo, + ) -> Result, VulkanError> { + let handle = { + let fields3_vk = create_info.to_vk_fields3(); + let fields2_vk = create_info.to_vk_fields2(&fields3_vk); + let mut fields1_extensions_vk = create_info.to_vk_fields1_extensions(); + let fields1_vk = create_info.to_vk_fields1(&fields2_vk, &mut fields1_extensions_vk); + let create_infos_vk = create_info.to_vk(&fields1_vk); + + let fns = device.fns(); + let mut output = MaybeUninit::uninit(); + + (fns.khr_ray_tracing_pipeline + .create_ray_tracing_pipelines_khr)( + device.handle(), + ash::vk::DeferredOperationKHR::null(), // TODO: RayTracing: deferred_operation + cache.map_or(ash::vk::PipelineCache::null(), |c| c.handle()), + 1, + &create_infos_vk, + ptr::null(), + output.as_mut_ptr(), + ) + .result() + .map_err(VulkanError::from)?; + output.assume_init() + }; + + Ok(Self::from_handle(device, handle, create_info)) + } + + pub unsafe fn from_handle( + device: Arc, + handle: ash::vk::Pipeline, + create_info: RayTracingPipelineCreateInfo, + ) -> Arc { + let RayTracingPipelineCreateInfo { + flags, + stages, + groups, + layout, + .. + } = create_info; + + let mut descriptor_binding_requirements: HashMap< + (u32, u32), + DescriptorBindingRequirements, + > = HashMap::default(); + for stage in &stages { + for (&loc, reqs) in stage + .entry_point + .info() + .descriptor_binding_requirements + .iter() + { + match descriptor_binding_requirements.entry(loc) { + Entry::Occupied(entry) => { + entry.into_mut().merge(reqs).expect("Could not produce an intersection of the shader descriptor requirements"); + } + Entry::Vacant(entry) => { + entry.insert(reqs.clone()); + } + } + } + } + let num_used_descriptor_sets = descriptor_binding_requirements + .keys() + .map(|loc| loc.0) + .max() + .map(|x| x + 1) + .unwrap_or(0); + Arc::new(Self { + handle, + device: InstanceOwnedDebugWrapper(device), + id: Self::next_id(), + + flags, + layout: DeviceOwnedDebugWrapper(layout), + + descriptor_binding_requirements, + num_used_descriptor_sets, + + groups, + stages, + }) + } + + pub fn groups(&self) -> &[RayTracingShaderGroupCreateInfo] { + &self.groups + } + + pub fn stages(&self) -> &[PipelineShaderStageCreateInfo] { + &self.stages + } + + pub fn device(&self) -> &Arc { + &self.device + } +} + +impl Pipeline for RayTracingPipeline { + #[inline] + fn bind_point(&self) -> PipelineBindPoint { + PipelineBindPoint::RayTracing + } + + #[inline] + fn layout(&self) -> &Arc { + &self.layout + } + + #[inline] + fn num_used_descriptor_sets(&self) -> u32 { + self.num_used_descriptor_sets + } + + #[inline] + fn descriptor_binding_requirements( + &self, + ) -> &HashMap<(u32, u32), DescriptorBindingRequirements> { + &self.descriptor_binding_requirements + } +} + +impl_id_counter!(RayTracingPipeline); + +unsafe impl VulkanObject for RayTracingPipeline { + type Handle = ash::vk::Pipeline; + + #[inline] + fn handle(&self) -> Self::Handle { + self.handle + } +} + +unsafe impl DeviceOwned for RayTracingPipeline { + #[inline] + fn device(&self) -> &Arc { + self.device() + } +} + +impl Drop for RayTracingPipeline { + #[inline] + fn drop(&mut self) { + unsafe { + let fns = self.device.fns(); + (fns.v1_0.destroy_pipeline)(self.device.handle(), self.handle, ptr::null()); + } + } +} + +/// Parameters to create a new `RayTracingPipeline`. +#[derive(Clone, Debug)] +pub struct RayTracingPipelineCreateInfo { + /// Additional properties of the pipeline. + /// + /// The default value is empty. + pub flags: PipelineCreateFlags, + + /// The compute shader stage to use. + /// + /// There is no default value. + pub stages: SmallVec<[PipelineShaderStageCreateInfo; 5]>, + + pub groups: SmallVec<[RayTracingShaderGroupCreateInfo; 5]>, + + pub max_pipeline_ray_recursion_depth: u32, + + pub dynamic_state: HashSet, + + /// The pipeline layout to use. + /// + /// There is no default value. + pub layout: Arc, + + /// The pipeline to use as a base when creating this pipeline. + /// + /// If this is `Some`, then `flags` must contain [`PipelineCreateFlags::DERIVATIVE`], + /// and the `flags` of the provided pipeline must contain + /// [`PipelineCreateFlags::ALLOW_DERIVATIVES`]. + /// + /// The default value is `None`. + pub base_pipeline: Option>, + + pub _ne: crate::NonExhaustive, +} + +impl RayTracingPipelineCreateInfo { + pub fn layout(layout: Arc) -> Self { + Self { + flags: PipelineCreateFlags::empty(), + stages: SmallVec::new(), + groups: SmallVec::new(), + max_pipeline_ray_recursion_depth: 0, + dynamic_state: Default::default(), + + layout, + + base_pipeline: None, + _ne: crate::NonExhaustive(()), + } + } + + pub(crate) fn to_vk<'a>( + &self, + fields1_vk: &'a RayTracingPipelineCreateInfoFields1Vk<'_>, + ) -> ash::vk::RayTracingPipelineCreateInfoKHR<'a> { + let &Self { + flags, + max_pipeline_ray_recursion_depth, + + ref layout, + ref base_pipeline, + .. + } = self; + + let RayTracingPipelineCreateInfoFields1Vk { + stages_vk, + groups_vk, + dynamic_state_vk, + } = fields1_vk; + + let mut val_vk = ash::vk::RayTracingPipelineCreateInfoKHR::default() + .flags(flags.into()) + .stages(stages_vk) + .groups(groups_vk) + .layout(layout.handle()) + .max_pipeline_ray_recursion_depth(max_pipeline_ray_recursion_depth) + .base_pipeline_handle( + base_pipeline + .as_ref() + .map_or(ash::vk::Pipeline::null(), |p| p.handle()), + ) + .base_pipeline_index(-1); + + if let Some(dynamic_state_vk) = dynamic_state_vk { + val_vk = val_vk.dynamic_state(dynamic_state_vk); + } + + return val_vk; + } + + pub(crate) fn to_vk_fields1<'a>( + &self, + fields2_vk: &'a RayTracingPipelineCreateInfoFields2Vk<'_>, + extensions_vk: &'a mut RayTracingPipelineCreateInfoFields1ExtensionsVk, + ) -> RayTracingPipelineCreateInfoFields1Vk<'a> { + let Self { stages, groups, .. } = self; + let RayTracingPipelineCreateInfoFields2Vk { + stages_fields1_vk, + dynamic_states_vk, + } = fields2_vk; + let RayTracingPipelineCreateInfoFields1ExtensionsVk { + stages_extensions_vk, + } = extensions_vk; + + let stages_vk: SmallVec<[_; 5]> = stages + .iter() + .zip(stages_fields1_vk) + .zip(stages_extensions_vk) + .map(|((stage, fields1), fields1_extensions_vk)| { + stage.to_vk(fields1, fields1_extensions_vk) + }) + .collect(); + + let groups_vk = groups + .iter() + .map(RayTracingShaderGroupCreateInfo::to_vk) + .collect(); + + let dynamic_state_vk = (!dynamic_states_vk.is_empty()).then(|| { + ash::vk::PipelineDynamicStateCreateInfo::default() + .flags(ash::vk::PipelineDynamicStateCreateFlags::empty()) + .dynamic_states(dynamic_states_vk) + }); + + RayTracingPipelineCreateInfoFields1Vk { + stages_vk, + groups_vk, + dynamic_state_vk, + } + } + + pub(crate) fn to_vk_fields1_extensions( + &self, + ) -> RayTracingPipelineCreateInfoFields1ExtensionsVk { + let Self { stages, .. } = self; + + let stages_extensions_vk = stages + .iter() + .map(|stage| stage.to_vk_extensions()) + .collect(); + + RayTracingPipelineCreateInfoFields1ExtensionsVk { + stages_extensions_vk, + } + } + + pub(crate) fn to_vk_fields2<'a>( + &self, + fields3_vk: &'a RayTracingPipelineCreateInfoFields3Vk, + ) -> RayTracingPipelineCreateInfoFields2Vk<'a> { + let Self { + stages, + dynamic_state, + .. + } = self; + + let stages_fields1_vk = stages + .iter() + .zip(fields3_vk.stages_fields2_vk.iter()) + .map(|(stage, fields3)| stage.to_vk_fields1(fields3)) + .collect(); + + let dynamic_states_vk = dynamic_state.iter().copied().map(Into::into).collect(); + + RayTracingPipelineCreateInfoFields2Vk { + stages_fields1_vk, + dynamic_states_vk, + } + } + + pub(crate) fn to_vk_fields3<'a>(&self) -> RayTracingPipelineCreateInfoFields3Vk { + let Self { stages, .. } = self; + + let stages_fields2_vk = stages.iter().map(|stage| stage.to_vk_fields2()).collect(); + + RayTracingPipelineCreateInfoFields3Vk { stages_fields2_vk } + } +} + +#[derive(Clone, Debug, Default)] +pub struct RayTracingShaderGroupCreateInfo { + pub group_type: ash::vk::RayTracingShaderGroupTypeKHR, // TODO: Custom type + pub general_shader: Option, + pub closest_hit_shader: Option, + pub any_hit_shader: Option, + pub intersection_shader: Option, +} + +impl RayTracingShaderGroupCreateInfo { + pub(crate) fn to_vk(&self) -> ash::vk::RayTracingShaderGroupCreateInfoKHR<'static> { + // We are not using pointers in the struct, so 'static is used. + ash::vk::RayTracingShaderGroupCreateInfoKHR::default() + .ty(self.group_type) + .general_shader(self.general_shader.unwrap_or(ash::vk::SHADER_UNUSED_KHR)) + .closest_hit_shader( + self.closest_hit_shader + .unwrap_or(ash::vk::SHADER_UNUSED_KHR), + ) + .any_hit_shader(self.any_hit_shader.unwrap_or(ash::vk::SHADER_UNUSED_KHR)) + .intersection_shader( + self.intersection_shader + .unwrap_or(ash::vk::SHADER_UNUSED_KHR), + ) + } +} + +pub struct RayTracingPipelineCreateInfoFields1Vk<'a> { + pub(crate) stages_vk: SmallVec<[ash::vk::PipelineShaderStageCreateInfo<'a>; 5]>, + pub(crate) groups_vk: SmallVec<[ash::vk::RayTracingShaderGroupCreateInfoKHR<'static>; 5]>, + pub(crate) dynamic_state_vk: Option>, +} + +pub struct RayTracingPipelineCreateInfoFields1ExtensionsVk { + pub(crate) stages_extensions_vk: SmallVec<[PipelineShaderStageCreateInfoExtensionsVk; 5]>, +} + +pub struct RayTracingPipelineCreateInfoFields2Vk<'a> { + pub(crate) stages_fields1_vk: SmallVec<[PipelineShaderStageCreateInfoFields1Vk<'a>; 5]>, + pub(crate) dynamic_states_vk: SmallVec<[ash::vk::DynamicState; 4]>, +} + +pub struct RayTracingPipelineCreateInfoFields3Vk { + pub(crate) stages_fields2_vk: SmallVec<[PipelineShaderStageCreateInfoFields2Vk; 5]>, +} +#[derive(Debug, Clone)] +pub struct ShaderBindingTable { + raygen: StridedDeviceAddressRegionKHR, + miss: StridedDeviceAddressRegionKHR, + hit: StridedDeviceAddressRegionKHR, + callable: StridedDeviceAddressRegionKHR, + buffer: Subbuffer<[u8]>, +} + +impl ShaderBindingTable { + pub fn raygen(&self) -> &StridedDeviceAddressRegionKHR { + &self.raygen + } + + pub fn miss(&self) -> &StridedDeviceAddressRegionKHR { + &self.miss + } + + pub fn hit(&self) -> &StridedDeviceAddressRegionKHR { + &self.hit + } + + pub fn callable(&self) -> &StridedDeviceAddressRegionKHR { + &self.callable + } + + pub(crate) fn buffer(&self) -> &Subbuffer<[u8]> { + &self.buffer + } + + pub fn new( + allocator: Arc, + ray_tracing_pipeline: &RayTracingPipeline, + miss_shader_count: u64, + hit_shader_count: u64, + callable_shader_count: u64, + ) -> Result> { + let handle_data = ray_tracing_pipeline + .device() + .get_ray_tracing_shader_group_handles( + &ray_tracing_pipeline, + 0, + ray_tracing_pipeline.groups().len() as u32, + )?; + + let properties = ray_tracing_pipeline.device().physical_device().properties(); + let handle_size_aligned = align_up( + handle_data.handle_size() as u64, + DeviceAlignment::new(properties.shader_group_handle_alignment.unwrap() as u64) + .expect("unexpected shader_group_handle_alignment"), + ); + + let shader_group_base_alignment = + DeviceAlignment::new(properties.shader_group_base_alignment.unwrap() as u64) + .expect("unexpected shader_group_base_alignment"); + + let raygen_stride = align_up(handle_size_aligned, shader_group_base_alignment); + + let mut raygen = StridedDeviceAddressRegionKHR { + stride: raygen_stride, + size: raygen_stride, + device_address: 0, + }; + let mut miss = StridedDeviceAddressRegionKHR { + stride: handle_size_aligned, + size: align_up( + handle_size_aligned * miss_shader_count, + shader_group_base_alignment, + ), + device_address: 0, + }; + let mut hit = StridedDeviceAddressRegionKHR { + stride: handle_size_aligned, + size: align_up( + handle_size_aligned * hit_shader_count, + shader_group_base_alignment, + ), + device_address: 0, + }; + let mut callable = StridedDeviceAddressRegionKHR { + stride: handle_size_aligned, + size: align_up( + handle_size_aligned * callable_shader_count, + shader_group_base_alignment, + ), + device_address: 0, + }; + + let sbt_buffer = Buffer::new_slice::( + allocator, + BufferCreateInfo { + usage: BufferUsage::TRANSFER_SRC + | BufferUsage::SHADER_DEVICE_ADDRESS + | BufferUsage::SHADER_BINDING_TABLE, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::HOST_SEQUENTIAL_WRITE + | MemoryTypeFilter::PREFER_DEVICE, + ..Default::default() + }, + raygen.size + miss.size + hit.size + callable.size, + ) + .expect("todo: raytracing: better error type"); + sbt_buffer + .buffer() + .set_debug_utils_object_name("Shader Binding Table Buffer".into()) + .unwrap(); + + raygen.device_address = sbt_buffer.buffer().device_address().unwrap().get(); + miss.device_address = raygen.device_address + raygen.size; + hit.device_address = miss.device_address + miss.size; + callable.device_address = hit.device_address + hit.size; + + { + let mut sbt_buffer_write = sbt_buffer.write().unwrap(); + + let mut handle_iter = handle_data.iter(); + + let handle_size = handle_data.handle_size() as usize; + sbt_buffer_write[..handle_size].copy_from_slice(handle_iter.next().unwrap()); + let mut offset = raygen.size as usize; + for _ in 0..miss_shader_count { + sbt_buffer_write[offset..offset + handle_size] + .copy_from_slice(handle_iter.next().unwrap()); + offset += miss.stride as usize; + } + offset = (raygen.size + miss.size) as usize; + for _ in 0..hit_shader_count { + sbt_buffer_write[offset..offset + handle_size] + .copy_from_slice(handle_iter.next().unwrap()); + offset += hit.stride as usize; + } + offset = (raygen.size + miss.size + hit.size) as usize; + for _ in 0..callable_shader_count { + sbt_buffer_write[offset..offset + handle_size] + .copy_from_slice(handle_iter.next().unwrap()); + offset += callable.stride as usize; + } + } + + Ok(Self { + raygen, + miss, + hit, + callable, + buffer: sbt_buffer, + }) + } +}