Skip to content

Commit

Permalink
ray_tracing_pipeline: Add miss shader
Browse files Browse the repository at this point in the history
  • Loading branch information
bwrsandman committed Apr 28, 2020
1 parent fda9d85 commit b6f7878
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 19 deletions.
30 changes: 28 additions & 2 deletions examples/src/bin/nv-ray-tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,39 @@ void main() {
const vec2 pixelCenter = coord + vec2(0.5);
const vec2 inUV = pixelCenter / vec2(gl_LaunchSizeNV.xy);
payload = vec4(inUV.x, inUV.y, 1.0f, 1.0f);
float aspect = float(gl_LaunchSizeNV.x) / gl_LaunchSizeNV.y;
vec3 lower_left_corner = vec3(-aspect, 1.0, -1.0);
vec3 horizontal = vec3(2.0 * aspect, 0.0, 0.0);
vec3 vertical = vec3(0.0, -2.0, 0.0);
vec3 origin = vec3(0.0, 0.0, 0.0);
vec3 direction = normalize(lower_left_corner + inUV.x * horizontal + inUV.y * vertical);
traceNV(scene, gl_RayFlagsOpaqueNV, 0xFF, 0, 0, 0, origin, 0.001, direction, 1000.0, 0);
imageStore(result, coord, payload);
}
"
}
}
let rs = rs::Shader::load(device.clone()).unwrap();

mod ms {
vulkano_shaders::shader! {
ty: "miss",
src: "#version 460 core
#extension GL_NV_ray_tracing : enable
layout(location = 0) rayPayloadInNV vec4 payload;
void main() {
vec3 unit_direction = normalize(gl_WorldRayDirectionNV);
float t = 0.5 * (unit_direction.y + 1.0);
payload = vec4(mix(vec3(0.5, 0.7, 1.0), vec3(1.0, 1.0, 1.0), t), 1.0f);
}
"
}
}
let ms = ms::Shader::load(device.clone()).unwrap();

// We set a limit to the recursion of a ray so that the shader does not run infinitely
let max_recursion_depth = 5;

Expand All @@ -176,6 +201,7 @@ void main() {
// We need at least one ray generation shader to describe where rays go
// and to store the result of their path tracing
.raygen_shader(rs.main_entry_point(), ())
.miss_shader(ms.main_entry_point(), ())
.build(device.clone())
.unwrap(),
);
Expand Down Expand Up @@ -222,7 +248,7 @@ void main() {
)
.unwrap();
let (miss_shader_binding_table, miss_buffer_future) = ImmutableBuffer::from_iter(
(0..0).map(|_| 5u8),
group_handles[group_handle_size..2 * group_handle_size].iter().copied(),
BufferUsage::ray_tracing(),
queue.clone(),
)
Expand Down
122 changes: 107 additions & 15 deletions vulkano/src/pipeline/ray_tracing_pipeline/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,21 @@ use VulkanObject;

/// Prototype for a `RayTracingPipeline`.
// TODO: we can optimize this by filling directly the raw vk structs
pub struct RayTracingPipelineBuilder<Rs, Rss> {
pub struct RayTracingPipelineBuilder<Rs, Rss, Ms, Mss> {
nv_extension: bool,
raygen_shader: Option<(Rs, Rss)>,
// TODO: Should be a list
miss_shader: Option<(Ms, Mss)>,
max_recursion_depth: u32,
}

impl RayTracingPipelineBuilder<EmptyEntryPointDummy, ()> {
impl RayTracingPipelineBuilder<EmptyEntryPointDummy, (), EmptyEntryPointDummy, ()> {
/// Builds a new empty builder using the `nv_ray_tracing` extension.
pub(super) fn nv(max_recursion_depth: u32) -> Self {
RayTracingPipelineBuilder {
nv_extension: true,
raygen_shader: None,
miss_shader: None,
max_recursion_depth,
}
}
Expand All @@ -52,16 +55,20 @@ impl RayTracingPipelineBuilder<EmptyEntryPointDummy, ()> {
RayTracingPipelineBuilder {
nv_extension: false,
raygen_shader: None,
miss_shader: None,
max_recursion_depth,
}
}
}

impl<Rs, Rss> RayTracingPipelineBuilder<Rs, Rss>
impl<Rs, Rss, Ms, Mss> RayTracingPipelineBuilder<Rs, Rss, Ms, Mss>
where
Rs: RayTracingEntryPointAbstract,
Rss: SpecializationConstants,
Rs::PipelineLayout: Clone + 'static + Send + Sync, // TODO: shouldn't be required
Ms: RayTracingEntryPointAbstract,
Mss: SpecializationConstants,
Ms::PipelineLayout: Clone + 'static + Send + Sync, // TODO: shouldn't be required
{
/// Builds the ray tracing pipeline, using an inferred a pipeline layout.
// TODO: replace Box<PipelineLayoutAbstract> with a PipelineUnion struct without template params
Expand Down Expand Up @@ -102,15 +109,28 @@ where
let pipeline_layout;

// Must be at least one stage with raygen
if let Some(ref raygen_shader) = self.raygen_shader {
pipeline_layout = Box::new(
PipelineLayoutDescTweaks::new(
self.raygen_shader.as_ref().unwrap().0.layout().clone(),
dynamic_buffers.into_iter().cloned(),
)
.build(device.clone())
.unwrap(),
) as Box<_>; // TODO: error
if let Some((ref shader, _)) = self.raygen_shader {
let layout = shader.layout().clone();
if let Some((ref shader, _)) = self.miss_shader {
let layout = layout.union(shader.layout().clone());
pipeline_layout = Box::new(
PipelineLayoutDescTweaks::new(
layout,
dynamic_buffers.into_iter().cloned(),
)
.build(device.clone())
.unwrap(),
) as Box<_>; // TODO: error
} else {
pipeline_layout = Box::new(
PipelineLayoutDescTweaks::new(
layout,
dynamic_buffers.into_iter().cloned(),
)
.build(device.clone())
.unwrap(),
) as Box<_>; // TODO: error
}
} else {
return Err(RayTracingPipelineCreationError::NoRaygenShader);
}
Expand Down Expand Up @@ -175,6 +195,31 @@ where
pName: self.raygen_shader.as_ref().unwrap().0.name().as_ptr(),
pSpecializationInfo: &raygen_shader_specialization as *const _,
};
let miss_shader_specialization = {
let spec_descriptors = Mss::descriptors();
let constants = &self.miss_shader.as_ref().unwrap().1;
vk::SpecializationInfo {
mapEntryCount: spec_descriptors.len() as u32,
pMapEntries: spec_descriptors.as_ptr() as *const _,
dataSize: mem::size_of_val(constants),
pData: constants as *const Mss as *const _,
}
};
let miss_stage = vk::PipelineShaderStageCreateInfo {
sType: vk::STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
pNext: ptr::null(),
flags: 0, // reserved
stage: vk::SHADER_STAGE_MISS_BIT_NV,
module: self
.miss_shader
.as_ref()
.unwrap()
.0
.module()
.internal_object(),
pName: self.miss_shader.as_ref().unwrap().0.name().as_ptr(),
pSpecializationInfo: &miss_shader_specialization as *const _,
};

let (pipeline, group_count) = if self.nv_extension {
let mut stages = SmallVec::<[_; 1]>::new();
Expand All @@ -192,6 +237,17 @@ where
});
stages.push(raygen_stage);

// Miss
groups.push(vk::RayTracingShaderGroupCreateInfoNV {
sType: vk::STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_NV,
pNext: ptr::null(),
type_: vk::RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_NV,
generalShader: stages.len() as u32,
closestHitShader: vk::SHADER_UNUSED,
anyHitShader: vk::SHADER_UNUSED,
intersectionShader: vk::SHADER_UNUSED,
});
stages.push(miss_stage);
unsafe {
let infos = vk::RayTracingPipelineCreateInfoNV {
sType: vk::STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_NV,
Expand Down Expand Up @@ -235,6 +291,19 @@ where
});
stages.push(raygen_stage);

// Miss
groups.push(vk::RayTracingShaderGroupCreateInfoKHR {
sType: vk::STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
pNext: ptr::null(),
type_: vk::RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
generalShader: stages.len() as u32,
closestHitShader: vk::SHADER_UNUSED,
anyHitShader: vk::SHADER_UNUSED,
intersectionShader: vk::SHADER_UNUSED,
pShaderGroupCaptureReplayHandle: ptr::null(), // TODO:
});
stages.push(miss_stage);

let library_info = vk::PipelineLibraryCreateInfoKHR {
sType: vk::STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR,
pNext: ptr::null(),
Expand Down Expand Up @@ -294,7 +363,7 @@ where
// TODO: add build_with_cache method
}

impl<Rs1, Rss1> RayTracingPipelineBuilder<Rs1, Rss1> {
impl<Rs1, Rss1, Ms1, Mss1> RayTracingPipelineBuilder<Rs1, Rss1, Ms1, Mss1> {
// TODO: add pipeline derivate system

/// Adds a raygen shader group to use.
Expand All @@ -305,7 +374,7 @@ impl<Rs1, Rss1> RayTracingPipelineBuilder<Rs1, Rss1> {
self,
shader: Rs2,
specialization_constants: Rss2,
) -> RayTracingPipelineBuilder<Rs2, Rss2>
) -> RayTracingPipelineBuilder<Rs2, Rss2, Ms1, Mss1>
where
Rs2: RayTracingEntryPointAbstract<SpecializationConstants = Rss2>,
Rss2: SpecializationConstants,
Expand All @@ -314,19 +383,42 @@ impl<Rs1, Rss1> RayTracingPipelineBuilder<Rs1, Rss1> {
nv_extension: self.nv_extension,
max_recursion_depth: self.max_recursion_depth,
raygen_shader: Some((shader, specialization_constants)),
miss_shader: self.miss_shader,
}
}

/// Adds a miss shader group to use.
// TODO: miss_shader should be a list
// TODO: correct specialization constants
#[inline]
pub fn miss_shader<Ms2, Mss2>(
self, shader: Ms2, specialization_constants: Mss2,
) -> RayTracingPipelineBuilder<Rs1, Rss1, Ms2, Mss2>
where
Ms2: RayTracingEntryPointAbstract<SpecializationConstants = Mss2>,
Mss2: SpecializationConstants,
{
RayTracingPipelineBuilder {
nv_extension: self.nv_extension,
max_recursion_depth: self.max_recursion_depth,
raygen_shader: self.raygen_shader,
miss_shader: Some((shader, specialization_constants)),
}
}
}

impl<Rs, Rss> Clone for RayTracingPipelineBuilder<Rs, Rss>
impl<Rs, Rss, Ms, Mss> Clone for RayTracingPipelineBuilder<Rs, Rss, Ms, Mss>
where
Rs: Clone,
Rss: Clone,
Ms: Clone,
Mss: Clone,
{
fn clone(&self) -> Self {
RayTracingPipelineBuilder {
nv_extension: self.nv_extension,
raygen_shader: self.raygen_shader.clone(),
miss_shader: self.miss_shader.clone(),
max_recursion_depth: self.max_recursion_depth,
}
}
Expand Down
4 changes: 2 additions & 2 deletions vulkano/src/pipeline/ray_tracing_pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl RayTracingPipeline<()> {
/// Returns a builder object that you can fill with the various parameters.
pub fn nv<'a>(
max_recursion_depth: u32,
) -> RayTracingPipelineBuilder<EmptyEntryPointDummy, ()> {
) -> RayTracingPipelineBuilder<EmptyEntryPointDummy, (), EmptyEntryPointDummy, ()> {
RayTracingPipelineBuilder::nv(max_recursion_depth)
}

Expand All @@ -66,7 +66,7 @@ impl RayTracingPipeline<()> {
/// Returns a builder object that you can fill with the various parameters.
pub fn khr<'a>(
max_recursion_depth: u32,
) -> RayTracingPipelineBuilder<EmptyEntryPointDummy, ()> {
) -> RayTracingPipelineBuilder<EmptyEntryPointDummy, (), EmptyEntryPointDummy, ()> {
RayTracingPipelineBuilder::khr(max_recursion_depth)
}
}
Expand Down

0 comments on commit b6f7878

Please sign in to comment.