From c3aed40e84f9509d3e108e4c60b429c812ba10c5 Mon Sep 17 00:00:00 2001 From: Philip Degarmo Date: Sun, 4 Aug 2024 10:29:41 -0700 Subject: [PATCH] Mesh Shader Support in raft-api for DX12 and Metal (#258) --- rafx-api/Cargo.toml | 8 +- rafx-api/src/backends/dx12/command_buffer.rs | 21 +- .../src/backends/dx12/internal/conversions.rs | 2 +- .../dx12/internal/mipmap_resources.rs | 7 +- rafx-api/src/backends/dx12/pipeline.rs | 388 ++++++++++++++--- rafx-api/src/backends/dx12/root_signature.rs | 14 +- rafx-api/src/backends/dx12/texture.rs | 8 +- rafx-api/src/backends/empty.rs | 5 +- rafx-api/src/backends/metal/command_buffer.rs | 61 ++- rafx-api/src/backends/metal/pipeline.rs | 236 +++++++--- rafx-api/src/backends/metal/shader_module.rs | 3 +- .../src/backends/vulkan/root_signature.rs | 4 - rafx-api/src/command_buffer.rs | 52 +++ rafx-api/src/extra/indirect.rs | 3 +- rafx-api/src/reflection.rs | 6 + rafx-api/src/types/definitions.rs | 2 +- rafx-api/src/types/format.rs | 6 +- rafx-api/src/types/misc.rs | 8 +- .../importers/material_importer.rs | 2 +- .../importers/material_instance_importer.rs | 4 +- .../assets/image/builder_compressed_image.rs | 4 +- rafx-base/src/trust_cell.rs | 2 - rafx-framework/src/graph/graph_pass.rs | 16 - .../src/assets/mesh_adv/mesh_adv_jobs.rs | 8 +- .../mesh_adv/internal/frame_packet.rs | 6 - rafx/Cargo.toml | 5 + .../meshshader_triangle.rs | 403 ++++++++++++++++++ .../meshshader_triangle/shaders/shaders.hlsl | 51 +++ .../meshshader_triangle/shaders/shaders.metal | 34 ++ 29 files changed, 1175 insertions(+), 194 deletions(-) create mode 100644 rafx/examples/meshshader_triangle/meshshader_triangle.rs create mode 100644 rafx/examples/meshshader_triangle/shaders/shaders.hlsl create mode 100644 rafx/examples/meshshader_triangle/shaders/shaders.metal diff --git a/rafx-api/Cargo.toml b/rafx-api/Cargo.toml index d23c6f06f..f12dae80e 100644 --- a/rafx-api/Cargo.toml +++ b/rafx-api/Cargo.toml @@ -30,19 +30,19 @@ backtrace = { version = "0.3", optional 
= true } raw-window-handle = "0.5" # vulkan/dx12 -gpu-allocator = { version = "0.22.0", default_features = false, optional = true } +gpu-allocator = { version = "0.22.0", default-features = false, optional = true } # vulkan ash = { version = "0.37", optional = true } ash-window = { version = "0.12", optional = true } # dx12 -windows = { version = "0.44", optional=true, features = ["Win32_Foundation", "Win32_Graphics_Dxgi_Common", "Win32_Security", "Win32_System", "Win32_System_Threading", "Win32_Graphics_Direct3D", "Win32_Graphics_Direct3D12", "Win32_Graphics_Dxgi", "Win32_Graphics_Direct3D_Dxc"] } -hassle-rs = { version = "0.10.0", optional=true } +windows = { version = "0.44", optional = true, features = ["Win32_Foundation", "Win32_Graphics_Dxgi_Common", "Win32_Security", "Win32_System", "Win32_System_Threading", "Win32_Graphics_Direct3D", "Win32_Graphics_Direct3D12", "Win32_Graphics_Dxgi", "Win32_Graphics_Direct3D_Dxc"] } +hassle-rs = { version = "0.10.0", optional = true } # metal [target.'cfg(target_os="macos")'.dependencies] -metal_rs = { package = "metal", version = "0.25", optional = true } +metal_rs = { package = "metal", version = "0.28", optional = true } core-graphics-types = { version = "0.1", optional = true } # Force core-graphics-0.22.3 due to semver breakage # https://github.com/servo/core-foundation-rs/pull/562 diff --git a/rafx-api/src/backends/dx12/command_buffer.rs b/rafx-api/src/backends/dx12/command_buffer.rs index 3805fa1cc..c938f8df1 100644 --- a/rafx-api/src/backends/dx12/command_buffer.rs +++ b/rafx-api/src/backends/dx12/command_buffer.rs @@ -23,7 +23,7 @@ use super::d3d12; pub struct RafxCommandBufferDx12Inner { //command_list_type: d3d12::D3D12_COMMAND_LIST_TYPE, command_list_base: d3d12::ID3D12CommandList, - command_list: d3d12::ID3D12GraphicsCommandList, + command_list: d3d12::ID3D12GraphicsCommandList6, command_allocator: d3d12::ID3D12CommandAllocator, bound_root_signature: Option, @@ -45,7 +45,7 @@ impl RafxCommandBufferDx12 { 
self.inner.borrow().command_list_base.clone() } - pub fn dx12_graphics_command_list(&self) -> d3d12::ID3D12GraphicsCommandList { + pub fn dx12_graphics_command_list(&self) -> d3d12::ID3D12GraphicsCommandList6 { self.inner.borrow().command_list.clone() } @@ -87,7 +87,7 @@ impl RafxCommandBufferDx12 { //TODO: Special handling for copy? let command_list_type = command_pool.command_list_type(); let command_list = unsafe { - let command_list: d3d12::ID3D12GraphicsCommandList = command_pool + let command_list: d3d12::ID3D12GraphicsCommandList6 = command_pool .queue() .device_context() .d3d12_device() @@ -716,6 +716,21 @@ impl RafxCommandBufferDx12 { Ok(()) } + pub fn cmd_draw_mesh( + &self, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + ) -> RafxResult<()> { + let inner = self.inner.borrow(); + unsafe { + inner + .command_list + .DispatchMesh(group_count_x, group_count_y, group_count_z); + } + Ok(()) + } + pub fn cmd_dispatch( &self, group_count_x: u32, diff --git a/rafx-api/src/backends/dx12/internal/conversions.rs b/rafx-api/src/backends/dx12/internal/conversions.rs index 4eb36f3df..eda8f18da 100644 --- a/rafx-api/src/backends/dx12/internal/conversions.rs +++ b/rafx-api/src/backends/dx12/internal/conversions.rs @@ -222,7 +222,7 @@ pub fn blend_state_blend_state_desc( || def.src_factor_alpha != RafxBlendFactor::One || def.dst_factor_alpha != RafxBlendFactor::Zero; - let mut desc = &mut blend_desc.RenderTarget[attachment_index as usize]; + let desc = &mut blend_desc.RenderTarget[attachment_index as usize]; desc.BlendEnable = blend_enable.into(); desc.RenderTargetWriteMask = def.masks.bits(); desc.BlendOp = def.blend_op.into(); diff --git a/rafx-api/src/backends/dx12/internal/mipmap_resources.rs b/rafx-api/src/backends/dx12/internal/mipmap_resources.rs index f2b2c6b4f..b9ea30efe 100644 --- a/rafx-api/src/backends/dx12/internal/mipmap_resources.rs +++ b/rafx-api/src/backends/dx12/internal/mipmap_resources.rs @@ -1,11 +1,6 @@ use super::d3d12; use 
crate::dx12::RafxDeviceContextDx12; -use crate::{ - RafxComputePipelineDef, RafxImmutableSamplerKey, RafxImmutableSamplers, RafxPipeline, - RafxResourceType, RafxResult, RafxRootSignature, RafxRootSignatureDef, RafxSampler, - RafxSamplerDef, RafxShader, RafxShaderModule, RafxShaderModuleDefDx12, RafxShaderResource, - RafxShaderStageDef, RafxShaderStageFlags, RafxShaderStageReflection, -}; +use crate::RafxResult; pub struct Dx12MipmapResources { //pub shader: RafxShader, diff --git a/rafx-api/src/backends/dx12/pipeline.rs b/rafx-api/src/backends/dx12/pipeline.rs index 55372d77f..5848ccfa7 100644 --- a/rafx-api/src/backends/dx12/pipeline.rs +++ b/rafx-api/src/backends/dx12/pipeline.rs @@ -8,6 +8,211 @@ use crate::{ MAX_RENDER_TARGET_ATTACHMENTS, }; use std::ffi::CString; +use windows::core::Vtable; + +macro_rules! pipeline_state_stream_subobject { + ($struct_name:ident, $constant:expr, $inner_type:ty) => { + #[repr(C, align(8))] + struct $struct_name { + subobject_type: d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE, + inner: $inner_type, + } + + impl Default for $struct_name { + fn default() -> Self { + Self { + subobject_type: $constant, + inner: <$inner_type>::default(), + } + } + } + }; +} + +macro_rules! 
pipeline_state_stream_subobject_with_default { + ($struct_name:ident, $constant:expr, $inner_type:ty, $default_value:expr) => { + #[repr(C, align(8))] + struct $struct_name { + subobject_type: d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE, + inner: $inner_type, + } + + impl Default for $struct_name { + fn default() -> Self { + Self { + subobject_type: $constant, + inner: $default_value, + } + } + } + }; +} + +pipeline_state_stream_subobject!( + PipelineStateStreamFlags, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_FLAGS, + d3d12::D3D12_PIPELINE_STATE_FLAGS +); +pipeline_state_stream_subobject!( + PipelineStateStreamNodeMask, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_NODE_MASK, + u32 +); +pipeline_state_stream_subobject_with_default!( + PipelineStateStreamRootSignature, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE, + *const d3d12::ID3D12RootSignature, + std::ptr::null_mut() +); +pipeline_state_stream_subobject!( + PipelineStateStreamInputLayout, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_INPUT_LAYOUT, + d3d12::D3D12_INPUT_LAYOUT_DESC +); +pipeline_state_stream_subobject!( + PipelineStateStreamIbStripCutValue, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_IB_STRIP_CUT_VALUE, + d3d12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE +); +pipeline_state_stream_subobject!( + PipelineStateStreamPrimitiveTopologyType, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PRIMITIVE_TOPOLOGY, + d3d12::D3D12_PRIMITIVE_TOPOLOGY_TYPE +); +pipeline_state_stream_subobject!( + PipelineStateStreamVS, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VS, + d3d12::D3D12_SHADER_BYTECODE +); +pipeline_state_stream_subobject!( + PipelineStateStreamGS, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_GS, + d3d12::D3D12_SHADER_BYTECODE +); +pipeline_state_stream_subobject!( + PipelineStateStreamStreamOutput, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_STREAM_OUTPUT, + d3d12::D3D12_STREAM_OUTPUT_DESC +); +pipeline_state_stream_subobject!( + PipelineStateStreamHS, + 
d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_HS, + d3d12::D3D12_SHADER_BYTECODE +); +pipeline_state_stream_subobject!( + PipelineStateStreamDS, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DS, + d3d12::D3D12_SHADER_BYTECODE +); +pipeline_state_stream_subobject!( + PipelineStateStreamPS, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PS, + d3d12::D3D12_SHADER_BYTECODE +); +pipeline_state_stream_subobject!( + PipelineStateStreamAS, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_AS, + d3d12::D3D12_SHADER_BYTECODE +); +pipeline_state_stream_subobject!( + PipelineStateStreamMS, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_MS, + d3d12::D3D12_SHADER_BYTECODE +); +pipeline_state_stream_subobject!( + PipelineStateStreamCS, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CS, + d3d12::D3D12_SHADER_BYTECODE +); +pipeline_state_stream_subobject!( + PipelineStateStreamBlendDesc, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_BLEND, + d3d12::D3D12_BLEND_DESC +); +pipeline_state_stream_subobject!( + PipelineStateStreamDepthStencil, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL, + d3d12::D3D12_DEPTH_STENCIL_DESC +); +pipeline_state_stream_subobject!( + PipelineStateStreamDepthStencil1, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL1, + d3d12::D3D12_DEPTH_STENCIL_DESC1 +); +// if (D3D12_SDK_VERSION >= 606) +//pipeline_state_stream_subobject!(PipelineStateStreamDepthStencil2, d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL2, d3d12::D3D12_DEPTH_STENCIL_DESC2); +pipeline_state_stream_subobject!( + PipelineStateStreamDepthStencilFormat, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL_FORMAT, + dxgi::Common::DXGI_FORMAT +); +pipeline_state_stream_subobject!( + PipelineStateStreamRasterizer, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER, + d3d12::D3D12_RASTERIZER_DESC +); +// if (D3D12_SDK_VERSION >= 608) +//pipeline_state_stream_subobject!(PipelineStateStreamRasterizer1, d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER1, 
d3d12::D3D12_RASTERIZER_DESC1); +// if (D3D12_SDK_VERSION >= 610) +//pipeline_state_stream_subobject!(PipelineStateStreamRasterizer2, d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER2, d3d12::D3D12_RASTERIZER_DESC2); +pipeline_state_stream_subobject!( + PipelineStateStreamRenderTargetFormats, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RENDER_TARGET_FORMATS, + d3d12::D3D12_RT_FORMAT_ARRAY +); +pipeline_state_stream_subobject!( + PipelineStateStreamSampleDesc, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_DESC, + dxgi::Common::DXGI_SAMPLE_DESC +); +pipeline_state_stream_subobject!( + PipelineStateStreamSampleMask, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_MASK, + u32 +); +pipeline_state_stream_subobject!( + PipelineStateStreamCachedPso, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CACHED_PSO, + d3d12::D3D12_CACHED_PIPELINE_STATE +); +pipeline_state_stream_subobject!( + PipelineStateStreamViewInstancing, + d3d12::D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VIEW_INSTANCING, + d3d12::D3D12_VIEW_INSTANCING_DESC +); + +#[derive(Default)] +#[repr(C)] +struct PipelineStreamObjectMesh { + flags: PipelineStateStreamFlags, + node_mask: PipelineStateStreamNodeMask, + root_signature: PipelineStateStreamRootSignature, + //input_layout: PipelineStateStreamInputLayout, + //ib_strip_cut_value: PipelineStateStreamIbStripCutValue, + primitive_topology_type: PipelineStateStreamPrimitiveTopologyType, + //vs: PipelineStateStreamVS, + //gs: PipelineStateStreamGS, + stream_output: PipelineStateStreamStreamOutput, + //hs: PipelineStateStreamHS, + //ds: PipelineStateStreamDS, + ps: PipelineStateStreamPS, + r#as: PipelineStateStreamAS, + ms: PipelineStateStreamMS, + //cs: PipelineStateStreamCS, + blend: PipelineStateStreamBlendDesc, + depth_stencil: PipelineStateStreamDepthStencil, + //depth_stencil1: PipelineStateStreamDepthStencil1, + //depth_stencil2: PipelineStateStreamDepthStencil2, + dsv_format: PipelineStateStreamDepthStencilFormat, + rasterizer: 
PipelineStateStreamRasterizer, + //rasterizer1: PipelineStateStreamRasterizer1, + //rasterizer2: PipelineStateStreamRasterizer2, + rtv_formats: PipelineStateStreamRenderTargetFormats, + sample_desc: PipelineStateStreamSampleDesc, + sample_mask: PipelineStateStreamSampleMask, + cached_pso: PipelineStateStreamCachedPso, + view_instancing: PipelineStateStreamViewInstancing, +} #[derive(Debug)] pub struct RafxPipelineDx12 { @@ -74,6 +279,8 @@ impl RafxPipelineDx12 { let mut ds_bytecode = None; let mut hs_bytecode = None; let mut gs_bytecode = None; + let mut ms_bytecode = None; + let mut as_bytecode = None; for stage in pipeline_def.shader.dx12_shader().unwrap().stages() { let module = stage.shader_module.dx12_shader_module().unwrap(); @@ -127,10 +334,26 @@ impl RafxPipelineDx12 { module.get_or_compile_bytecode(&stage.reflection.entry_point_name, "gs_6_0")?, ); } - //stage.reflection.shader_stage; - // somehow get bytecode? reflection defines entry point and type of shader - // probably query the shader module, it compiles and caches. 
we could have pre-compiled - // and look it up as well + + if stage + .reflection + .shader_stage + .intersects(RafxShaderStageFlags::MESH) + { + ms_bytecode = Some( + module.get_or_compile_bytecode(&stage.reflection.entry_point_name, "ms_6_5")?, + ); + } + + if stage + .reflection + .shader_stage + .intersects(RafxShaderStageFlags::AMPLIFICATION) + { + as_bytecode = Some( + module.get_or_compile_bytecode(&stage.reflection.entry_point_name, "as_6_5")?, + ); + } } // can leave everything zero'd out @@ -245,54 +468,123 @@ impl RafxPipelineDx12 { rtv_formats[i] = pipeline_def.color_formats[i].into(); } - let pipeline_state_desc = d3d12::D3D12_GRAPHICS_PIPELINE_STATE_DESC { - pRootSignature: ::windows::core::ManuallyDrop::new( - &pipeline_def - .root_signature - .dx12_root_signature() - .unwrap() - .dx12_root_signature() - .clone(), - ), - VS: vs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), - PS: ps_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), - DS: ds_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), - GS: gs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), - HS: hs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), - StreamOutput: stream_out_desc, - BlendState: super::internal::conversions::blend_state_blend_state_desc( - pipeline_def.blend_state, - render_target_count, - ), - SampleMask: u32::MAX, - RasterizerState: super::internal::conversions::rasterizer_state_rasterizer_desc( - pipeline_def.rasterizer_state, - ), - DepthStencilState: depth_stencil_desc, //super::internal::conversions::depth_state_depth_stencil_desc(pipeline_def.depth_state), - InputLayout: input_layout_desc, - IBStripCutValue: d3d12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED, - PrimitiveTopologyType: pipeline_def.primitive_topology.into(), - NumRenderTargets: render_target_count as u32, - RTVFormats: rtv_formats, - DSVFormat: pipeline_def + let blend_state = super::internal::conversions::blend_state_blend_state_desc( + pipeline_def.blend_state, + render_target_count, + 
); + + let rasterizer_state = super::internal::conversions::rasterizer_state_rasterizer_desc( + pipeline_def.rasterizer_state, + ); + + let pipeline_state = if ms_bytecode.is_some() { + // Treat as a graphics pipeline using mesh shaders + use windows::core::Interface; + let device2 = device_context + .d3d12_device() + .cast::() + .unwrap(); + + //let dx12_root_sig = pipeline_def.root_signature.dx12_root_signature().unwrap().dx12_root_signature(); + let root_sig_ptr = pipeline_def + .root_signature + .dx12_root_signature() + .unwrap() + .dx12_root_signature() + .as_raw(); + + let mut pipeline_stream_object = PipelineStreamObjectMesh::default(); + + pipeline_stream_object.root_signature.inner = + root_sig_ptr as *const d3d12::ID3D12RootSignature; + pipeline_stream_object.r#as.inner = + as_bytecode.map(|x| *x.bytecode()).unwrap_or_default(); + pipeline_stream_object.ms.inner = + ms_bytecode.map(|x| *x.bytecode()).unwrap_or_default(); + pipeline_stream_object.ps.inner = + ps_bytecode.map(|x| *x.bytecode()).unwrap_or_default(); + pipeline_stream_object.blend.inner = blend_state; + pipeline_stream_object.sample_mask.inner = u32::MAX; + pipeline_stream_object.rasterizer.inner = rasterizer_state; + pipeline_stream_object.depth_stencil.inner = depth_stencil_desc; + pipeline_stream_object.primitive_topology_type.inner = + pipeline_def.primitive_topology.into(); + pipeline_stream_object.rtv_formats.inner.NumRenderTargets = render_target_count as u32; + pipeline_stream_object.rtv_formats.inner.RTFormats = rtv_formats; + pipeline_stream_object.dsv_format.inner = pipeline_def .depth_stencil_format .map(|x| x.into()) - .unwrap_or(dxgi::Common::DXGI_FORMAT_UNKNOWN), - SampleDesc: sample_desc, - CachedPSO: cached_pipeline_state, - Flags: d3d12::D3D12_PIPELINE_STATE_FLAG_NONE, - NodeMask: 0, - }; + .unwrap_or(dxgi::Common::DXGI_FORMAT_UNKNOWN); + pipeline_stream_object.sample_desc.inner = sample_desc; + pipeline_stream_object.cached_pso.inner = cached_pipeline_state; + 
pipeline_stream_object.flags.inner = d3d12::D3D12_PIPELINE_STATE_FLAG_NONE; + pipeline_stream_object.node_mask.inner = 0; + + //pipeline_stream_object.vs.inner = vs_bytecode.map(|x| *x.bytecode()).as_ref().unwrap_or_default(); + let pipeline_state_desc = d3d12::D3D12_PIPELINE_STATE_STREAM_DESC { + SizeInBytes: std::mem::size_of::(), + pPipelineStateSubobjectStream: ((&mut pipeline_stream_object) + as *mut PipelineStreamObjectMesh) + as *mut std::ffi::c_void, + }; + let pipeline_state: d3d12::ID3D12PipelineState = unsafe { + device2 + .CreatePipelineState( + &pipeline_state_desc as *const d3d12::D3D12_PIPELINE_STATE_STREAM_DESC, + ) + .unwrap() + }; - //TODO: More hashing required if using PSO cache + pipeline_state + } else { + // Treat as a standard graphics pipeline + + let pipeline_state_desc = d3d12::D3D12_GRAPHICS_PIPELINE_STATE_DESC { + pRootSignature: ::windows::core::ManuallyDrop::new( + &pipeline_def + .root_signature + .dx12_root_signature() + .unwrap() + .dx12_root_signature() + .clone(), + ), + VS: vs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), + PS: ps_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), + DS: ds_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), + GS: gs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), + HS: hs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(), + StreamOutput: stream_out_desc, + BlendState: blend_state, + SampleMask: u32::MAX, + RasterizerState: rasterizer_state, + DepthStencilState: depth_stencil_desc, //super::internal::conversions::depth_state_depth_stencil_desc(pipeline_def.depth_state), + InputLayout: input_layout_desc, + IBStripCutValue: d3d12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED, + PrimitiveTopologyType: pipeline_def.primitive_topology.into(), + NumRenderTargets: render_target_count as u32, + RTVFormats: rtv_formats, + DSVFormat: pipeline_def + .depth_stencil_format + .map(|x| x.into()) + .unwrap_or(dxgi::Common::DXGI_FORMAT_UNKNOWN), + SampleDesc: sample_desc, + CachedPSO: 
cached_pipeline_state, + Flags: d3d12::D3D12_PIPELINE_STATE_FLAG_NONE, + NodeMask: 0, + }; - //TODO: Try to find cached PSO + //TODO: More hashing required if using PSO cache - // If we didn't have it cached, build it - let pipeline: d3d12::ID3D12PipelineState = unsafe { - device_context - .d3d12_device() - .CreateGraphicsPipelineState(&pipeline_state_desc)? + //TODO: Try to find cached PSO + + // If we didn't have it cached, build it + let pipeline_state: d3d12::ID3D12PipelineState = unsafe { + device_context + .d3d12_device() + .CreateGraphicsPipelineState(&pipeline_state_desc)? + }; + + pipeline_state }; let topology = pipeline_def.primitive_topology.into(); @@ -300,7 +592,7 @@ impl RafxPipelineDx12 { let pipeline = RafxPipelineDx12 { root_signature: pipeline_def.root_signature.clone(), pipeline_type: pipeline_def.root_signature.pipeline_type(), - pipeline, + pipeline: pipeline_state, topology, vertex_buffer_strides, }; diff --git a/rafx-api/src/backends/dx12/root_signature.rs b/rafx-api/src/backends/dx12/root_signature.rs index 9dd5e7c5d..40347a602 100644 --- a/rafx-api/src/backends/dx12/root_signature.rs +++ b/rafx-api/src/backends/dx12/root_signature.rs @@ -562,7 +562,17 @@ impl RafxRootSignatureDx12 { if !all_used_shader_stage.intersects(RafxShaderStageFlags::FRAGMENT) { root_signature_flags |= d3d12::D3D12_ROOT_SIGNATURE_FLAG_DENY_PIXEL_SHADER_ROOT_ACCESS; } - // There are other deny flags we could use? + + //NOTE: PIX and renderdoc will fail to debug mesh shaders if this flag is enabled + // because they rely on instrumenting the shader and writing resources. This is likely + // fixed as of ~Aug 2024 in renderdoc. But realistically these deny flags only really help + // old hardware anyways. 
+ // if !all_used_shader_stage.intersects(RafxShaderStageFlags::MESH) { + // root_signature_flags |= d3d12::D3D12_ROOT_SIGNATURE_FLAG_DENY_MESH_SHADER_ROOT_ACCESS; + // } + // if !all_used_shader_stage.intersects(RafxShaderStageFlags::AMPLIFICATION) { + // root_signature_flags |= d3d12::D3D12_ROOT_SIGNATURE_FLAG_DENY_AMPLIFICATION_SHADER_ROOT_ACCESS; + // } // // Make the root signature @@ -605,7 +615,7 @@ impl RafxRootSignatureDx12 { root_sig_string.GetBufferPointer() as *const u8, root_sig_string.GetBufferSize(), ); - let str = String::from_utf8_lossy(sig_string); + //let str = String::from_utf8_lossy(sig_string); //println!("root sig {}", str); device_context diff --git a/rafx-api/src/backends/dx12/texture.rs b/rafx-api/src/backends/dx12/texture.rs index b71fd6707..7394482fe 100644 --- a/rafx-api/src/backends/dx12/texture.rs +++ b/rafx-api/src/backends/dx12/texture.rs @@ -775,7 +775,7 @@ impl RafxTextureDx12 { Flags: d3d12::D3D12_RESOURCE_FLAG_NONE, }; - let mut resource_states = RafxResourceState::UNDEFINED; + let resource_states = RafxResourceState::UNDEFINED; if create_uav_chain { desc.Flags |= d3d12::D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS; @@ -891,7 +891,7 @@ impl RafxTextureDx12 { if create_uav_chain { srv_uav_handle_count += texture_def.mip_count; } - let mut resource_desc = unsafe { image.image.GetDesc() }; + let resource_desc = unsafe { image.image.GetDesc() }; let is_cube_map = texture_def .resource_type @@ -1047,8 +1047,8 @@ impl RafxTextureDx12 { next_dsv_handle = Some(next_dsv_handle.unwrap().add_offset(1)); } - let mut first_rtv_slice = next_rtv_handle; - let mut first_dsv_slice = next_dsv_handle; + let first_rtv_slice = next_rtv_handle; + let first_dsv_slice = next_dsv_handle; for mip_level in 0..texture_def.mip_count { if texture_def diff --git a/rafx-api/src/backends/empty.rs b/rafx-api/src/backends/empty.rs index c33257be7..22e81591a 100644 --- a/rafx-api/src/backends/empty.rs +++ b/rafx-api/src/backends/empty.rs @@ -14,7 +14,7 @@ impl 
RafxApiEmpty { pub fn destroy(&mut self) -> RafxResult<()> { unimplemented!() } } - + #[derive(Clone)] pub struct RafxDeviceContextEmpty; impl RafxDeviceContextEmpty { @@ -160,8 +160,9 @@ impl RafxCommandBufferEmpty { pub fn cmd_draw_indexed_instanced(&self, index_count: u32, first_index: u32, instance_count: u32, first_instance: u32, vertex_offset: i32) -> RafxResult<()> { unimplemented!() } pub fn cmd_draw_indirect(&self, indirect_buffer: &RafxBufferEmpty, indirect_buffer_offset_in_bytes: u32, draw_count: u32) -> RafxResult<()> { unimplemented!() } pub fn cmd_draw_indexed_indirect(&self, indirect_buffer: &RafxBufferEmpty, indirect_buffer_offset_in_bytes: u32, draw_count: u32) -> RafxResult<()> { unimplemented!() } + pub fn cmd_draw_mesh(&self, group_count_x: u32, group_count_y: u32, group_count_z: u32) -> RafxResult<()> { unimplemented!() } - pub fn cmd_dispatch(&self, group_count_x: u32, group_count_y: u32, group_count_z: u32) -> RafxResult<()> { unimplemented!() } + pub fn cmd_dispatch(&self, group_count_x: u32, group_count_y: u32, group_count_z: u32) -> RafxResult<()> { unimplemented!() } pub fn cmd_resource_barrier(&self, buffer_barriers: &[RafxBufferBarrier], texture_barriers: &[RafxTextureBarrier]) -> RafxResult<()> { unimplemented!() } pub fn cmd_copy_buffer_to_buffer(&self, src_buffer: &RafxBufferEmpty, dst_buffer: &RafxBufferEmpty, params: &RafxCmdCopyBufferToBufferParams) -> RafxResult<()> { unimplemented!() } diff --git a/rafx-api/src/backends/metal/command_buffer.rs b/rafx-api/src/backends/metal/command_buffer.rs index d46704886..73fb5e67a 100644 --- a/rafx-api/src/backends/metal/command_buffer.rs +++ b/rafx-api/src/backends/metal/command_buffer.rs @@ -35,9 +35,14 @@ pub struct RafxCommandBufferMetalInner { primitive_type: MTLPrimitiveType, current_render_targets_width: u32, current_render_targets_height: u32, - compute_threads_per_group_x: u32, - compute_threads_per_group_y: u32, - compute_threads_per_group_z: u32, + + // Thread group size for 
compute + threads_per_compute_threadgroup: MTLSize, + + // Tracks thread group size for mesh-shader based render pipelines + threads_per_object_threadgroup: MTLSize, + threads_per_mesh_threadgroup: MTLSize, + group_debug_name_stack: Vec, debug_names_enabled: bool, } @@ -110,9 +115,9 @@ impl RafxCommandBufferMetal { primitive_type: MTLPrimitiveType::Triangle, current_render_targets_width: 0, current_render_targets_height: 0, - compute_threads_per_group_x: 0, - compute_threads_per_group_y: 0, - compute_threads_per_group_z: 0, + threads_per_compute_threadgroup: MTLSize::default(), + threads_per_object_threadgroup: MTLSize::default(), + threads_per_mesh_threadgroup: MTLSize::default(), current_index_buffer: None, current_index_buffer_byte_offset: 0, current_index_buffer_type: MTLIndexType::UInt16, @@ -460,6 +465,12 @@ impl RafxCommandBufferMetal { } inner.primitive_type = render_encoder_info.mtl_primitive_type; + + inner.threads_per_object_threadgroup = + render_encoder_info.threads_per_object_threadgroup; + inner.threads_per_mesh_threadgroup = + render_encoder_info.threads_per_mesh_threadgroup; + self.flush_render_targets_to_make_readable(&mut *inner); } RafxPipelineType::Compute => { @@ -486,10 +497,8 @@ impl RafxCommandBufferMetal { } let compute_encoder_info = pipeline.compute_encoder_info.as_ref().unwrap(); - let compute_threads_per_group = compute_encoder_info.compute_threads_per_group; - inner.compute_threads_per_group_x = compute_threads_per_group[0]; - inner.compute_threads_per_group_y = compute_threads_per_group[1]; - inner.compute_threads_per_group_z = compute_threads_per_group[2]; + inner.threads_per_compute_threadgroup = + compute_encoder_info.threads_per_threadgroup; inner .compute_encoder @@ -932,6 +941,31 @@ impl RafxCommandBufferMetal { Ok(()) } + pub fn cmd_draw_mesh( + &self, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + ) -> RafxResult<()> { + let inner = self.inner.borrow(); + + let group_count = MTLSize { + width: 
(group_count_x as metal_rs::NSUInteger) + * inner.threads_per_mesh_threadgroup.width, + height: (group_count_y as metal_rs::NSUInteger) + * inner.threads_per_mesh_threadgroup.height, + depth: (group_count_z as metal_rs::NSUInteger) + * inner.threads_per_mesh_threadgroup.depth, + }; + + inner.render_encoder.as_ref().unwrap().draw_mesh_threads( + group_count, + inner.threads_per_object_threadgroup, + inner.threads_per_mesh_threadgroup, + ); + Ok(()) + } + pub fn cmd_dispatch( &self, group_count_x: u32, @@ -940,11 +974,6 @@ impl RafxCommandBufferMetal { ) -> RafxResult<()> { let inner = self.inner.borrow(); self.wait_for_barriers(&*inner)?; - let thread_per_group = MTLSize { - width: inner.compute_threads_per_group_x as _, - height: inner.compute_threads_per_group_y as _, - depth: inner.compute_threads_per_group_z as _, - }; let group_count = MTLSize { width: group_count_x as _, @@ -956,7 +985,7 @@ impl RafxCommandBufferMetal { .compute_encoder .as_ref() .unwrap() - .dispatch_thread_groups(group_count, thread_per_group); + .dispatch_thread_groups(group_count, inner.threads_per_compute_threadgroup); Ok(()) } diff --git a/rafx-api/src/backends/metal/pipeline.rs b/rafx-api/src/backends/metal/pipeline.rs index bc744ac0c..ccd31acbb 100644 --- a/rafx-api/src/backends/metal/pipeline.rs +++ b/rafx-api/src/backends/metal/pipeline.rs @@ -4,6 +4,18 @@ use crate::{ RafxRootSignature, RafxShaderStageFlags, }; +fn threads_per_group_to_mtl_size( + compute_threads_per_group: Option<[u32; 3]> +) -> RafxResult { + let compute_threads_per_group = compute_threads_per_group + .ok_or("Metal shaders must have threadgroup size specified in reflection data")?; + Ok(metal_rs::MTLSize::new( + compute_threads_per_group[0] as _, + compute_threads_per_group[1] as _, + compute_threads_per_group[2] as _, + )) +} + fn metal_entry_point_name(name: &str) -> &str { // "main" is not an allowed entry point name. 
spirv_cross adds a 0 to the end of any // unallowed entry point names so do that here too @@ -26,7 +38,7 @@ unsafe impl Sync for MetalPipelineState {} #[derive(Debug)] pub(crate) struct PipelineComputeEncoderInfo { - pub compute_threads_per_group: [u32; 3], + pub(crate) threads_per_threadgroup: metal_rs::MTLSize, } #[derive(Debug)] @@ -40,6 +52,10 @@ pub(crate) struct PipelineRenderEncoderInfo { pub(crate) mtl_depth_clip_mode: metal_rs::MTLDepthClipMode, pub(crate) mtl_depth_stencil_state: Option, pub(crate) mtl_primitive_type: metal_rs::MTLPrimitiveType, + + // Used when pipeline is mesh-shader based + pub(crate) threads_per_object_threadgroup: metal_rs::MTLSize, + pub(crate) threads_per_mesh_threadgroup: metal_rs::MTLSize, } // for metal_rs::DepthStencilState @@ -84,16 +100,13 @@ impl RafxPipelineMetal { device_context: &RafxDeviceContextMetal, pipeline_def: &RafxGraphicsPipelineDef, ) -> RafxResult { - let pipeline = metal_rs::RenderPipelineDescriptor::new(); - - if device_context.device_info().debug_names_enabled { - if let Some(debug_name) = pipeline_def.debug_name { - pipeline.set_label(debug_name); - } - } - let mut vertex_function = None; let mut fragment_function = None; + let mut mesh_function = None; + let mut object_function = None; + + let mut threads_per_object_threadgroup = metal_rs::MTLSize::default(); + let mut threads_per_mesh_threadgroup = metal_rs::MTLSize::default(); for stage in pipeline_def.shader.metal_shader().unwrap().stages() { if stage @@ -129,66 +142,163 @@ impl RafxPipelineMetal { .get_function(entry_point, None)?, ); } - } - let vertex_function = vertex_function.ok_or("Could not find vertex function")?; - - pipeline.set_vertex_function(Some(vertex_function.as_ref())); - pipeline.set_fragment_function(fragment_function.as_deref()); - pipeline.set_sample_count(pipeline_def.sample_count.into()); - - let vertex_descriptor = metal_rs::VertexDescriptor::new(); - for attribute in &pipeline_def.vertex_layout.attributes { - let buffer_index = 
- super::util::vertex_buffer_adjusted_buffer_index(attribute.buffer_index); - let attribute_descriptor = vertex_descriptor - .attributes() - .object_at(attribute.location as _) - .unwrap(); - attribute_descriptor.set_buffer_index(buffer_index); - attribute_descriptor.set_format(attribute.format.into()); - attribute_descriptor.set_offset(attribute.byte_offset as _); - } + if stage + .reflection + .shader_stage + .intersects(RafxShaderStageFlags::MESH) + { + let entry_point = metal_entry_point_name(&stage.reflection.entry_point_name); + assert!(mesh_function.is_none()); + mesh_function = Some( + stage + .shader_module + .metal_shader_module() + .unwrap() + .library() + .get_function(entry_point, None)?, + ); + threads_per_mesh_threadgroup = + threads_per_group_to_mtl_size(stage.reflection.compute_threads_per_group)?; + } - for (index, binding) in pipeline_def.vertex_layout.buffers.iter().enumerate() { - let buffer_index = super::util::vertex_buffer_adjusted_buffer_index(index as u32); - let layout_descriptor = vertex_descriptor.layouts().object_at(buffer_index).unwrap(); - layout_descriptor.set_stride(binding.stride as _); - layout_descriptor.set_step_function(binding.rate.into()); - layout_descriptor.set_step_rate(1); - } - pipeline.set_vertex_descriptor(Some(vertex_descriptor)); - - pipeline.set_input_primitive_topology(pipeline_def.primitive_topology.into()); - - //TODO: Pass in number of color attachments? 
- super::util::blend_def_to_attachment( - pipeline_def.blend_state, - &mut pipeline.color_attachments(), - pipeline_def.color_formats.len(), - ); - - for (index, &color_format) in pipeline_def.color_formats.iter().enumerate() { - pipeline - .color_attachments() - .object_at(index as _) - .unwrap() - .set_pixel_format(color_format.into()); + if stage + .reflection + .shader_stage + .intersects(RafxShaderStageFlags::AMPLIFICATION) + { + let entry_point = metal_entry_point_name(&stage.reflection.entry_point_name); + assert!(object_function.is_none()); + object_function = Some( + stage + .shader_module + .metal_shader_module() + .unwrap() + .library() + .get_function(entry_point, None)?, + ); + threads_per_object_threadgroup = + threads_per_group_to_mtl_size(stage.reflection.compute_threads_per_group)?; + } } - if let Some(depth_format) = pipeline_def.depth_stencil_format { - if depth_format.has_depth() { - pipeline.set_depth_attachment_pixel_format(depth_format.into()); + let pipeline = if vertex_function.is_some() { + // Take the traditional vertex shader-based rasterization path + let pipeline = metal_rs::RenderPipelineDescriptor::new(); + + if device_context.device_info().debug_names_enabled { + if let Some(debug_name) = pipeline_def.debug_name { + pipeline.set_label(debug_name); + } } - if depth_format.has_stencil() { - pipeline.set_stencil_attachment_pixel_format(depth_format.into()); + pipeline.set_vertex_function(vertex_function.as_deref()); + pipeline.set_fragment_function(fragment_function.as_deref()); + pipeline.set_sample_count(pipeline_def.sample_count.into()); + + let vertex_descriptor = metal_rs::VertexDescriptor::new(); + for attribute in &pipeline_def.vertex_layout.attributes { + let buffer_index = + super::util::vertex_buffer_adjusted_buffer_index(attribute.buffer_index); + let attribute_descriptor = vertex_descriptor + .attributes() + .object_at(attribute.location as _) + .unwrap(); + attribute_descriptor.set_buffer_index(buffer_index); + 
attribute_descriptor.set_format(attribute.format.into()); + attribute_descriptor.set_offset(attribute.byte_offset as _); } - } - let pipeline = device_context - .device() - .new_render_pipeline_state(pipeline.as_ref())?; + for (index, binding) in pipeline_def.vertex_layout.buffers.iter().enumerate() { + let buffer_index = super::util::vertex_buffer_adjusted_buffer_index(index as u32); + let layout_descriptor = + vertex_descriptor.layouts().object_at(buffer_index).unwrap(); + layout_descriptor.set_stride(binding.stride as _); + layout_descriptor.set_step_function(binding.rate.into()); + layout_descriptor.set_step_rate(1); + } + pipeline.set_vertex_descriptor(Some(vertex_descriptor)); + + pipeline.set_input_primitive_topology(pipeline_def.primitive_topology.into()); + + // + // Shared code beyond this point for vertex/mesh path + // + //TODO: Pass in number of color attachments? + super::util::blend_def_to_attachment( + pipeline_def.blend_state, + &mut pipeline.color_attachments(), + pipeline_def.color_formats.len(), + ); + + for (index, &color_format) in pipeline_def.color_formats.iter().enumerate() { + pipeline + .color_attachments() + .object_at(index as _) + .unwrap() + .set_pixel_format(color_format.into()); + } + + if let Some(depth_format) = pipeline_def.depth_stencil_format { + if depth_format.has_depth() { + pipeline.set_depth_attachment_pixel_format(depth_format.into()); + } + + if depth_format.has_stencil() { + pipeline.set_stencil_attachment_pixel_format(depth_format.into()); + } + } + + device_context + .device() + .new_render_pipeline_state(pipeline.as_ref())? 
+ } else if mesh_function.is_some() { + let pipeline = metal_rs::MeshRenderPipelineDescriptor::new(); + + if device_context.device_info().debug_names_enabled { + if let Some(debug_name) = pipeline_def.debug_name { + pipeline.set_label(debug_name); + } + } + + pipeline.set_object_function(object_function.as_deref()); + pipeline.set_mesh_function(mesh_function.as_deref()); + pipeline.set_fragment_function(fragment_function.as_deref()); + pipeline.set_raster_sample_count(pipeline_def.sample_count.into()); + + // + // Shared code beyond this point for vertex/mesh path + // + super::util::blend_def_to_attachment( + pipeline_def.blend_state, + &mut pipeline.color_attachments(), + pipeline_def.color_formats.len(), + ); + + for (index, &color_format) in pipeline_def.color_formats.iter().enumerate() { + pipeline + .color_attachments() + .object_at(index as _) + .unwrap() + .set_pixel_format(color_format.into()); + } + + if let Some(depth_format) = pipeline_def.depth_stencil_format { + if depth_format.has_depth() { + pipeline.set_depth_attachment_pixel_format(depth_format.into()); + } + + if depth_format.has_stencil() { + pipeline.set_stencil_attachment_pixel_format(depth_format.into()); + } + } + + device_context + .device() + .new_mesh_render_pipeline_state(pipeline.as_ref())? + } else { + Err("Could not find vertex or mesh function in the provided shader when creating pipeline")? 
+ }; let mtl_cull_mode = pipeline_def.rasterizer_state.cull_mode.into(); let mtl_triangle_fill_mode = pipeline_def.rasterizer_state.fill_mode.into(); @@ -223,6 +333,8 @@ impl RafxPipelineMetal { mtl_depth_clip_mode, mtl_depth_stencil_state, mtl_primitive_type, + threads_per_object_threadgroup, + threads_per_mesh_threadgroup, }; Ok(RafxPipelineMetal { @@ -280,7 +392,7 @@ impl RafxPipelineMetal { .new_compute_pipeline_state(pipeline.as_ref())?; let compute_encoder_info = PipelineComputeEncoderInfo { - compute_threads_per_group: compute_threads_per_group.unwrap(), + threads_per_threadgroup: threads_per_group_to_mtl_size(compute_threads_per_group)?, }; Ok(RafxPipelineMetal { diff --git a/rafx-api/src/backends/metal/shader_module.rs b/rafx-api/src/backends/metal/shader_module.rs index 3db6264dd..c7e41fef4 100644 --- a/rafx-api/src/backends/metal/shader_module.rs +++ b/rafx-api/src/backends/metal/shader_module.rs @@ -54,7 +54,8 @@ impl RafxShaderModuleMetal { src: &str, ) -> RafxResult { let compile_options = metal_rs::CompileOptions::new(); - compile_options.set_language_version(MTLLanguageVersion::V2_1); + // 3.0 required for mesh shaders + compile_options.set_language_version(MTLLanguageVersion::V3_0); let library = device_context .device() .new_library_with_source(src, &compile_options)?; diff --git a/rafx-api/src/backends/vulkan/root_signature.rs b/rafx-api/src/backends/vulkan/root_signature.rs index 8dce305ac..2dd7b0b3b 100644 --- a/rafx-api/src/backends/vulkan/root_signature.rs +++ b/rafx-api/src/backends/vulkan/root_signature.rs @@ -4,10 +4,6 @@ use ash::vk; use fnv::FnvHashMap; use std::sync::Arc; -// Not currently exposed -#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] -pub(crate) struct DynamicDescriptorIndex(pub(crate) u32); - //TODO: Could compact this down quite a bit #[derive(Clone, Debug)] pub(crate) struct DescriptorInfo { diff --git a/rafx-api/src/command_buffer.rs b/rafx-api/src/command_buffer.rs index 34f60f3c2..61d47ac82 100644 --- 
a/rafx-api/src/command_buffer.rs +++ b/rafx-api/src/command_buffer.rs @@ -955,6 +955,58 @@ impl RafxCommandBuffer { } } + pub fn cmd_draw_mesh( + &self, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + ) -> RafxResult<()> { + match self { + #[cfg(feature = "rafx-dx12")] + RafxCommandBuffer::Dx12(inner) => { + inner.cmd_draw_mesh(group_count_x, group_count_y, group_count_z) + } + #[cfg(feature = "rafx-vulkan")] + RafxCommandBuffer::Vk(_) => { + let _ = group_count_x; + let _ = group_count_y; + let _ = group_count_z; + unimplemented!() + } + #[cfg(feature = "rafx-metal")] + RafxCommandBuffer::Metal(inner) => { + inner.cmd_draw_mesh(group_count_x, group_count_y, group_count_z) + } + #[cfg(feature = "rafx-gles2")] + RafxCommandBuffer::Gles2(_) => { + let _ = group_count_x; + let _ = group_count_y; + let _ = group_count_z; + unimplemented!() + } + #[cfg(feature = "rafx-gles3")] + RafxCommandBuffer::Gles3(_) => { + let _ = group_count_x; + let _ = group_count_y; + let _ = group_count_z; + unimplemented!() + } + #[cfg(any( + feature = "rafx-empty", + not(any( + feature = "rafx-dx12", + feature = "rafx-metal", + feature = "rafx-vulkan", + feature = "rafx-gles2", + feature = "rafx-gles3" + )) + ))] + RafxCommandBuffer::Empty(inner) => { + inner.cmd_draw_mesh(group_count_x, group_count_y, group_count_z) + } + } + } + /// Dispatch the current pipeline. Only usable with compute pipelines. 
pub fn cmd_dispatch( &self, diff --git a/rafx-api/src/extra/indirect.rs b/rafx-api/src/extra/indirect.rs index 0ecbbce87..5538ce223 100644 --- a/rafx-api/src/extra/indirect.rs +++ b/rafx-api/src/extra/indirect.rs @@ -210,8 +210,7 @@ impl<'a> RafxIndexedIndirectCommandEncoder<'a> { unsafe { #[cfg(feature = "rafx-dx12")] if self.is_dx12 { - let mut ptr = - self.mapped_memory as *mut RafxDrawIndexedIndirectCommandWithPushConstant; + let ptr = self.mapped_memory as *mut RafxDrawIndexedIndirectCommandWithPushConstant; let push_constant = command.first_instance; *ptr.add(index) = RafxDrawIndexedIndirectCommandWithPushConstant { command, diff --git a/rafx-api/src/reflection.rs b/rafx-api/src/reflection.rs index 5dade0477..152752284 100644 --- a/rafx-api/src/reflection.rs +++ b/rafx-api/src/reflection.rs @@ -262,7 +262,13 @@ pub struct RafxShaderStageReflection { //pub vertex_inputs: Vec, pub shader_stage: RafxShaderStageFlags, pub resources: Vec, + + // Metal needs the thread count passed in when dispatching compute or mesh shaders. So it needs + // to be provided here when working with rafx-api directly. Normally this can be populated + // automatically by shader reflection. Despite the naming, this applies for mesh and amplification + // shaders as well. It may be worth renaming this in the future pub compute_threads_per_group: Option<[u32; 3]>, + pub entry_point_name: String, // Right now we will infer mappings based on spirv_cross default behavior, but likely will want // to allow providing them explicitly. 
This isn't implemented yet diff --git a/rafx-api/src/types/definitions.rs b/rafx-api/src/types/definitions.rs index 8141dcfe1..ad4079979 100644 --- a/rafx-api/src/types/definitions.rs +++ b/rafx-api/src/types/definitions.rs @@ -451,7 +451,7 @@ pub struct RafxVertexLayoutBuffer { } /// Describes how vertex attributes are laid out within one or more buffers -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub struct RafxVertexLayout { pub attributes: Vec, pub buffers: Vec, diff --git a/rafx-api/src/types/format.rs b/rafx-api/src/types/format.rs index 580bd90a0..339f1d8fd 100644 --- a/rafx-api/src/types/format.rs +++ b/rafx-api/src/types/format.rs @@ -922,14 +922,14 @@ impl From for RafxFormat { // DxgiCommon::B10G11R11_UFLOAT_PACK32 => RafxFormat::B10G11R11_UFLOAT_PACK32, // DxgiCommon::E5B9G9R9_UFLOAT_PACK32 => RafxFormat::E5B9G9R9_UFLOAT_PACK32, DxgiCommon::DXGI_FORMAT_D16_UNORM => RafxFormat::D16_UNORM, - DxgiCommon::DXGI_FORMAT_D24_UNORM_S8_UINT => RafxFormat::X8_D24_UNORM_PACK32, + //DxgiCommon::DXGI_FORMAT_D24_UNORM_S8_UINT => RafxFormat::X8_D24_UNORM_PACK32, DxgiCommon::DXGI_FORMAT_D32_FLOAT => RafxFormat::D32_SFLOAT, // DxgiCommon::S8_UINT => RafxFormat::S8_UINT, // DxgiCommon::D16_UNORM_S8_UINT => RafxFormat::D16_UNORM_S8_UINT, DxgiCommon::DXGI_FORMAT_D24_UNORM_S8_UINT => RafxFormat::D24_UNORM_S8_UINT, DxgiCommon::DXGI_FORMAT_D32_FLOAT_S8X24_UINT => RafxFormat::D32_SFLOAT_S8_UINT, - DxgiCommon::DXGI_FORMAT_BC1_UNORM => RafxFormat::BC1_RGB_UNORM_BLOCK, - DxgiCommon::DXGI_FORMAT_BC1_UNORM_SRGB => RafxFormat::BC1_RGB_SRGB_BLOCK, + //DxgiCommon::DXGI_FORMAT_BC1_UNORM => RafxFormat::BC1_RGB_UNORM_BLOCK, + //DxgiCommon::DXGI_FORMAT_BC1_UNORM_SRGB => RafxFormat::BC1_RGB_SRGB_BLOCK, DxgiCommon::DXGI_FORMAT_BC1_UNORM => RafxFormat::BC1_RGBA_UNORM_BLOCK, DxgiCommon::DXGI_FORMAT_BC1_UNORM_SRGB => RafxFormat::BC1_RGBA_SRGB_BLOCK, DxgiCommon::DXGI_FORMAT_BC2_UNORM => RafxFormat::BC2_UNORM_BLOCK, diff --git 
a/rafx-api/src/types/misc.rs b/rafx-api/src/types/misc.rs index facfbf784..78fd1b769 100644 --- a/rafx-api/src/types/misc.rs +++ b/rafx-api/src/types/misc.rs @@ -380,7 +380,13 @@ bitflags::bitflags! { const GEOMETRY = 8; const FRAGMENT = 16; const COMPUTE = 32; - const ALL_GRAPHICS = 0x1F; + + // Mesh shaders + const MESH = 64; + // This is an object shader in metal + const AMPLIFICATION = 128; + + const ALL_GRAPHICS = 0xDF; const ALL = 0x7FFF_FFFF; } } diff --git a/rafx-assets/src/assets/graphics_pipeline/importers/material_importer.rs b/rafx-assets/src/assets/graphics_pipeline/importers/material_importer.rs index 0862b3ac8..d80246727 100644 --- a/rafx-assets/src/assets/graphics_pipeline/importers/material_importer.rs +++ b/rafx-assets/src/assets/graphics_pipeline/importers/material_importer.rs @@ -128,7 +128,7 @@ impl JobProcessor for MaterialJobProcessor { |handle_factory| { //let shader_module = job_system::make_handle_to_default_artifact(job_api, shader_module); let mut passes = Vec::default(); - for pass_entry in asset_data.passes().resolve_entries()?.into_iter() { + for pass_entry in &asset_data.passes().resolve_entries()? { let pass_entry = asset_data.passes().entry(*pass_entry); let fixed_function_state = diff --git a/rafx-assets/src/assets/graphics_pipeline/importers/material_instance_importer.rs b/rafx-assets/src/assets/graphics_pipeline/importers/material_instance_importer.rs index 0df8e9b32..710f4cdda 100644 --- a/rafx-assets/src/assets/graphics_pipeline/importers/material_instance_importer.rs +++ b/rafx-assets/src/assets/graphics_pipeline/importers/material_instance_importer.rs @@ -149,9 +149,7 @@ impl JobProcessor for MaterialInstanceJobProcessor { handle_factory.make_handle_to_default_artifact(asset_data.material().get()?); let mut slot_assignments = Vec::default(); - for slot_assignent_entry in - asset_data.slot_assignments().resolve_entries()?.into_iter() - { + for slot_assignent_entry in &asset_data.slot_assignments().resolve_entries()? 
{ let slot_assignment = asset_data.slot_assignments().entry(*slot_assignent_entry); diff --git a/rafx-assets/src/assets/image/builder_compressed_image.rs b/rafx-assets/src/assets/image/builder_compressed_image.rs index 7d4b47902..bba0d5ac9 100644 --- a/rafx-assets/src/assets/image/builder_compressed_image.rs +++ b/rafx-assets/src/assets/image/builder_compressed_image.rs @@ -88,11 +88,11 @@ impl JobProcessor for GpuCompressedImageJobProcessor { }) } else { let mut layers = Vec::default(); - for &layer_entry in layer_entries.into_iter() { + for &layer_entry in &layer_entries { let layer = imported_data.data_layers().entry(layer_entry); let mip_level_entries = layer.mip_levels().resolve_entries()?; let mut mip_levels = Vec::default(); - for &mip_level_entry in mip_level_entries.into_iter() { + for &mip_level_entry in &mip_level_entries { let mip_level = layer.mip_levels().entry(mip_level_entry); mip_levels.push(ImageAssetDataMipLevel { width: mip_level.width().get()?, diff --git a/rafx-base/src/trust_cell.rs b/rafx-base/src/trust_cell.rs index 959fd332a..6e7bf63a2 100644 --- a/rafx-base/src/trust_cell.rs +++ b/rafx-base/src/trust_cell.rs @@ -12,7 +12,6 @@ use std::prelude::v1::*; -#[cfg(feature = "std")] use std::error::Error; use std::{ @@ -46,7 +45,6 @@ impl Display for InvalidBorrow { } } -#[cfg(feature = "std")] impl Error for InvalidBorrow { fn description(&self) -> &str { "This error is returned when you try to borrow immutably when it's already \ diff --git a/rafx-framework/src/graph/graph_pass.rs b/rafx-framework/src/graph/graph_pass.rs index c4e750a8c..85cafd8e9 100644 --- a/rafx-framework/src/graph/graph_pass.rs +++ b/rafx-framework/src/graph/graph_pass.rs @@ -43,14 +43,6 @@ impl RenderGraphPassBufferBarriers { } } -/// All the barriers required for a single node (i.e. subpass). Nodes represent passes that may be -/// merged to be subpasses within a single pass. 
-#[derive(Debug)] -pub struct RenderGraphNodeBufferBarriers { - #[allow(unused)] - pub(super) barriers: FnvHashMap, -} - pub const MAX_COLOR_ATTACHMENTS: usize = 4; pub const MAX_RESOLVE_ATTACHMENTS: usize = 4; @@ -121,14 +113,6 @@ pub struct PrepassBarrier { pub buffer_barriers: Vec, } -#[derive(Debug)] -pub struct PostpassBarrier { - // layout transition - pub image_barriers: Vec, - pub buffer_barriers: Vec, - // resolve? probably do that in rafx api level -} - #[derive(Debug)] pub struct PrepassImageBarrier { pub image: PhysicalImageId, diff --git a/rafx-plugins/src/assets/mesh_adv/mesh_adv_jobs.rs b/rafx-plugins/src/assets/mesh_adv/mesh_adv_jobs.rs index 14288cbd1..3d5e046a4 100644 --- a/rafx-plugins/src/assets/mesh_adv/mesh_adv_jobs.rs +++ b/rafx-plugins/src/assets/mesh_adv/mesh_adv_jobs.rs @@ -211,7 +211,7 @@ impl JobProcessor for MeshAdvMeshJobProcessor { // let asset_data = context.asset::(context.input.asset_id)?; let mut materials = Vec::default(); - for entry in asset_data.material_slots().resolve_entries()?.into_iter() { + for entry in &asset_data.material_slots().resolve_entries()? { let entry = asset_data.material_slots().entry(*entry).get()?; materials.push(entry); } @@ -230,7 +230,7 @@ impl JobProcessor for MeshAdvMeshJobProcessor { let mut all_indices = PushBuffer::new(16384); let mut mesh_part_data = Vec::default(); - for entry in imported_data.mesh_parts().resolve_entries()?.into_iter() { + for entry in &imported_data.mesh_parts().resolve_entries()? { let entry = imported_data.mesh_parts().entry(*entry); // @@ -347,7 +347,7 @@ impl JobProcessor for MeshAdvMeshJobProcessor { for (entry, part_data) in imported_data .mesh_parts() .resolve_entries()? 
- .into_iter() + .iter() .zip(mesh_part_data) { let entry = imported_data.mesh_parts().entry(*entry); @@ -461,7 +461,7 @@ impl JobProcessor for MeshAdvModelJobProcessor { context.asset::(context.input.asset_id)?; let mut lods = Vec::default(); - for entry in asset_data.lods().resolve_entries()?.into_iter() { + for entry in &asset_data.lods().resolve_entries()? { let lod = asset_data.lods().entry(*entry); let mesh_handle = handle_factory.make_handle_to_default_artifact(lod.mesh().get()?); diff --git a/rafx-plugins/src/features/mesh_adv/internal/frame_packet.rs b/rafx-plugins/src/features/mesh_adv/internal/frame_packet.rs index 257d09e01..f36c4f9c5 100644 --- a/rafx-plugins/src/features/mesh_adv/internal/frame_packet.rs +++ b/rafx-plugins/src/features/mesh_adv/internal/frame_packet.rs @@ -81,12 +81,6 @@ pub type MeshAdvFramePacket = FramePacket; // PREPARE //--------- -#[derive(Clone)] -pub struct MeshAdvPartMaterialDescriptorSetPair { - pub textured_descriptor_set: Option, - pub untextured_descriptor_set: Option, -} - #[derive(Hash, PartialEq, Eq, Clone)] pub struct MeshAdvBatchedPassKey { pub phase: RenderPhaseIndex, diff --git a/rafx/Cargo.toml b/rafx/Cargo.toml index 4155a307f..82a048daf 100644 --- a/rafx/Cargo.toml +++ b/rafx/Cargo.toml @@ -88,5 +88,10 @@ name = "asset_triangle" path = "examples/asset_triangle/asset_triangle.rs" required-features = ["assets"] +[[example]] +name = "meshshader_triangle" +path = "examples/meshshader_triangle/meshshader_triangle.rs" +required-features = [] + [package.metadata.docs.rs] features = ["rafx-vulkan", "framework", "assets", "renderer"] diff --git a/rafx/examples/meshshader_triangle/meshshader_triangle.rs b/rafx/examples/meshshader_triangle/meshshader_triangle.rs new file mode 100644 index 000000000..8fa952939 --- /dev/null +++ b/rafx/examples/meshshader_triangle/meshshader_triangle.rs @@ -0,0 +1,403 @@ +use log::LevelFilter; + +use rafx::api::*; + +const WINDOW_WIDTH: u32 = 900; +const WINDOW_HEIGHT: u32 = 600; + +fn 
main() { + env_logger::Builder::from_default_env() + .default_format_timestamp_nanos(true) + .filter_level(LevelFilter::Debug) + .init(); + + run().unwrap(); +} + +fn run() -> RafxResult<()> { + // + // Init SDL2 (winit and anything that uses raw-window-handle works too!) + // + let sdl2_systems = sdl2_init(); + + // + // Create the api. GPU programming is fundamentally unsafe, so all rafx APIs should be + // considered unsafe. However, rafx APIs are only gated by unsafe if they can cause undefined + // behavior on the CPU for reasons other than interacting with the GPU. + // + let mut api = unsafe { + RafxApi::new( + &sdl2_systems.window, + &sdl2_systems.window, + &Default::default(), + )? + }; + + // Wrap all of this so that it gets dropped before we drop the API object. This ensures a nice + // clean shutdown. + { + // A cloneable device handle, these are lightweight and can be passed across threads + let device_context = api.device_context(); + + // + // Allocate a graphics queue. By default, there is just one graphics queue and it is shared. + // There currently is no API for customizing this but the code would be easy to adapt to act + // differently. Most recommendations I've seen are to just use one graphics queue. (The + // rendering hardware is shared among them) + // + let graphics_queue = device_context.create_queue(RafxQueueType::Graphics)?; + + // + // Create a swapchain + // + let (window_width, window_height) = sdl2_systems.window.drawable_size(); + let swapchain = device_context.create_swapchain( + &sdl2_systems.window, + &sdl2_systems.window, + &graphics_queue, + &RafxSwapchainDef { + width: window_width, + height: window_height, + enable_vsync: true, + color_space_priority: vec![RafxSwapchainColorSpace::Srgb], + }, + )?; + + // + // Wrap the swapchain in this helper to cut down on boilerplate. This helper is + // multithreaded-rendering friendly! 
The PresentableFrame it returns can be sent to another + // thread and presented from there, and any errors are returned back to the main thread + // when the next image is acquired. The helper also ensures that the swapchain is rebuilt + // as necessary. + // + let mut swapchain_helper = RafxSwapchainHelper::new(&device_context, swapchain, None)?; + + // + // Some default data we can render + // + #[rustfmt::skip] + let vertex_data = [ + 0.0f32, 0.5, 1.0, 0.0, 0.0, + -0.5, -0.5, 0.0, 1.0, 0.0, + 0.5, 0.5, 0.0, 0.0, 1.0, + ]; + + let uniform_data = [1.0f32, 0.0, 1.0, 1.0]; + + // + // Create command pools/command buffers. The command pools need to be immutable while they are + // being processed by a queue, so create one per swapchain image. + // + // Create vertex buffers (with position/color information) and uniform buffers that we + // can bind to pass additional info. + // + // In this demo the mesh shader hardcodes its positions and color, so these buffers + // are updated every frame but never bound. 
Buffers also need to be immutable while + // processed, so we need one per swapchain image + // + let mut command_pools = Vec::with_capacity(swapchain_helper.rotating_frame_count()); + let mut command_buffers = Vec::with_capacity(swapchain_helper.rotating_frame_count()); + let mut vertex_buffers = Vec::with_capacity(swapchain_helper.rotating_frame_count()); + let mut uniform_buffers = Vec::with_capacity(swapchain_helper.rotating_frame_count()); + + for _ in 0..swapchain_helper.rotating_frame_count() { + let mut command_pool = + graphics_queue.create_command_pool(&RafxCommandPoolDef { transient: true })?; + + let command_buffer = command_pool.create_command_buffer(&RafxCommandBufferDef { + is_secondary: false, + })?; + + let vertex_buffer = device_context + .create_buffer(&RafxBufferDef::for_staging_vertex_buffer_data(&vertex_data))?; + vertex_buffer.copy_to_host_visible_buffer(&vertex_data)?; + + let uniform_buffer = device_context.create_buffer( + &RafxBufferDef::for_staging_uniform_buffer_data(&uniform_data), + )?; + uniform_buffer.copy_to_host_visible_buffer(&uniform_data)?; + + command_pools.push(command_pool); + command_buffers.push(command_buffer); + vertex_buffers.push(vertex_buffer); + uniform_buffers.push(uniform_buffer); + } + + // + // Load a shader from source - this part is API-specific. vulkan will want SPV, metal wants + // source code or even better a pre-compiled library. But the metal compiler toolchain only + // works on mac/windows and is a command line tool without programmatic access. + // + // In an engine, it would be better to pack different formats depending on the platform + // being built. Higher level rafx crates can help with this. But this is meant as a simple + // example without needing those crates. + // + // RafxShaderPackage holds all the data needed to create a GPU shader module object. It is + // heavy-weight, fully owning the data. We create by loading files from disk. 
This object + // can be stored as an opaque, binary object and loaded directly if you prefer. + // + // RafxShaderModuleDef is a lightweight reference to this data. Here we create it from the + // RafxShaderPackage, but you can create it yourself if you already loaded the data in some + // other way. + // + // The resulting shader modules represent a loaded shader GPU object that is used to create + // shaders. Shader modules can be discarded once the graphics pipeline is built. + // + let shaders_base_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("examples/meshshader_triangle/shaders"); + + let mut mesh_shader_package = RafxShaderPackage::default(); + let hlsl_shader_string = + std::fs::read_to_string(shaders_base_path.join("shaders.hlsl")).unwrap(); + let msl_shader_string = + std::fs::read_to_string(shaders_base_path.join("shaders.metal")).unwrap(); + mesh_shader_package.dx12 = Some(RafxShaderPackageDx12::Src(hlsl_shader_string)); + mesh_shader_package.metal = Some(RafxShaderPackageMetal::Src(msl_shader_string)); + + let mesh_shader_module = + device_context.create_shader_module(mesh_shader_package.module_def())?; + let frag_shader_module = + device_context.create_shader_module(mesh_shader_package.module_def())?; + + // + // Create the shader object by combining the stages + // + // Hardcode the reflection data required to interact with the shaders. This can be generated + // offline and loaded with the shader but this is not currently provided in rafx-api itself. 
+ // (But see the shader pipeline in higher-level rafx crates for example usage, generated + // from spirv_cross) + // + + let mesh_shader_stage_def = RafxShaderStageDef { + shader_module: mesh_shader_module, + reflection: RafxShaderStageReflection { + entry_point_name: "main_ms".to_string(), + shader_stage: RafxShaderStageFlags::MESH, + compute_threads_per_group: Some([128, 1, 1]), + resources: vec![], + }, + }; + + let frag_shader_stage_def = RafxShaderStageDef { + shader_module: frag_shader_module, + reflection: RafxShaderStageReflection { + entry_point_name: "main_ps".to_string(), + shader_stage: RafxShaderStageFlags::FRAGMENT, + compute_threads_per_group: None, + resources: vec![], + }, + }; + + // + // Combine the shader stages into a single shader + // + let shader = + device_context.create_shader(vec![mesh_shader_stage_def, frag_shader_stage_def])?; + + // + // Create the root signature object - it represents the pipeline layout and can be shared among + // shaders. But one per shader is fine. 
+ // + let root_signature = device_context.create_root_signature(&RafxRootSignatureDef { + shaders: &[shader.clone()], + immutable_samplers: &[], + })?; + + let vertex_layout = RafxVertexLayout::default(); + + let pipeline = device_context.create_graphics_pipeline(&RafxGraphicsPipelineDef { + shader: &shader, + root_signature: &root_signature, + vertex_layout: &vertex_layout, + blend_state: &Default::default(), + depth_state: &Default::default(), + rasterizer_state: &Default::default(), + color_formats: &[swapchain_helper.format()], + sample_count: RafxSampleCount::SampleCount1, + depth_stencil_format: None, + primitive_topology: RafxPrimitiveTopology::TriangleList, + debug_name: None, + })?; + + let start_time = std::time::Instant::now(); + + // + // SDL2 window pumping + // + log::info!("Starting window event loop"); + let mut event_pump = sdl2_systems + .context + .event_pump() + .expect("Could not create sdl event pump"); + + 'running: loop { + if !process_input(&mut event_pump) { + break 'running; + } + + let elapsed_seconds = start_time.elapsed().as_secs_f32(); + + #[rustfmt::skip] + let vertex_data = [ + 0.0f32, 0.5, 1.0, 0.0, 0.0, + 0.5 - (elapsed_seconds.cos() / 2. + 0.5), -0.5, 0.0, 1.0, 0.0, + -0.5 + (elapsed_seconds.cos() / 2. 
+ 0.5), -0.5, 0.0, 0.0, 1.0, + ]; + + let color = (elapsed_seconds.cos() + 1.0) / 2.0; + let uniform_data = [color, 0.0, 1.0 - color, 1.0]; + + // + // Acquire swapchain image + // + let (window_width, window_height) = sdl2_systems.window.vulkan_drawable_size(); + let presentable_frame = + swapchain_helper.acquire_next_image(window_width, window_height, None)?; + let swapchain_texture = presentable_frame.swapchain_texture(); + + // + // Use the command pool/buffer assigned to this frame + // + let cmd_pool = &mut command_pools[presentable_frame.rotating_frame_index()]; + let cmd_buffer = &command_buffers[presentable_frame.rotating_frame_index()]; + let vertex_buffer = &vertex_buffers[presentable_frame.rotating_frame_index()]; + let uniform_buffer = &uniform_buffers[presentable_frame.rotating_frame_index()]; + + // + // Update the buffers + // + vertex_buffer.copy_to_host_visible_buffer(&vertex_data)?; + uniform_buffer.copy_to_host_visible_buffer(&uniform_data)?; + + // + // Record the command buffer. 
For now just transition it between layouts + // + cmd_pool.reset_command_pool()?; + + cmd_buffer.begin()?; + // Put it into a layout where we can draw on it + cmd_buffer.cmd_resource_barrier( + &[], + &[RafxTextureBarrier::state_transition( + &swapchain_texture, + RafxResourceState::PRESENT, + RafxResourceState::RENDER_TARGET, + )], + )?; + + cmd_buffer.cmd_begin_render_pass( + &[RafxColorRenderTargetBinding { + texture: &swapchain_texture, + load_op: RafxLoadOp::Clear, + store_op: RafxStoreOp::Store, + array_slice: None, + mip_slice: None, + clear_value: RafxColorClearValue([0.2, 0.2, 0.2, 1.0]), + resolve_target: None, + resolve_store_op: RafxStoreOp::DontCare, + resolve_mip_slice: None, + resolve_array_slice: None, + }], + None, + )?; + + cmd_buffer.cmd_bind_pipeline(&pipeline)?; + + cmd_buffer.cmd_draw_mesh(1, 1, 1)?; + + cmd_buffer.cmd_end_render_pass()?; + + // Put it into a layout where we can present it + cmd_buffer.cmd_resource_barrier( + &[], + &[RafxTextureBarrier::state_transition( + &swapchain_texture, + RafxResourceState::RENDER_TARGET, + RafxResourceState::PRESENT, + )], + )?; + cmd_buffer.end()?; + + // + // Present the image + // + let result = presentable_frame.present(&graphics_queue, &[&cmd_buffer]); + result.unwrap(); + } + + // Wait for all GPU work to complete before destroying resources it is using + graphics_queue.wait_for_queue_idle()?; + } + + // Optional, but calling this verifies that all rafx objects/device contexts have been + // destroyed and where they were created. Good for finding unintended leaks! 
+ api.destroy()?; + + Ok(()) +} + +pub struct Sdl2Systems { + pub context: sdl2::Sdl, + pub video_subsystem: sdl2::VideoSubsystem, + pub window: sdl2::video::Window, +} + +pub fn sdl2_init() -> Sdl2Systems { + // Setup SDL + let context = sdl2::init().expect("Failed to initialize sdl2"); + let video_subsystem = context + .video() + .expect("Failed to create sdl video subsystem"); + + // Create the window + let mut window_binding = video_subsystem.window("Rafx Example", WINDOW_WIDTH, WINDOW_HEIGHT); + + let window_builder = window_binding + .position_centered() + .allow_highdpi() + .resizable(); + + #[cfg(target_os = "macos")] + let window_builder = window_builder.metal_view(); + + let window = window_builder.build().expect("Failed to create window"); + + Sdl2Systems { + context, + video_subsystem, + window, + } +} + +fn process_input(event_pump: &mut sdl2::EventPump) -> bool { + use sdl2::event::Event; + use sdl2::keyboard::Keycode; + + for event in event_pump.poll_iter() { + //log::trace!("{:?}", event); + match event { + // + // Halt if the user requests to close the window + // + Event::Quit { .. } => return false, + + // + // Close if the escape key is hit + // + Event::KeyDown { + keycode: Some(keycode), + keymod: _modifiers, + .. 
+ } => { + //log::trace!("Key Down {:?} {:?}", keycode, modifiers); + if keycode == Keycode::Escape { + return false; + } + } + + _ => {} + } + } + + true +} diff --git a/rafx/examples/meshshader_triangle/shaders/shaders.hlsl b/rafx/examples/meshshader_triangle/shaders/shaders.hlsl new file mode 100644 index 000000000..e9cd552b4 --- /dev/null +++ b/rafx/examples/meshshader_triangle/shaders/shaders.hlsl @@ -0,0 +1,51 @@ +#define MAX_MESHLET_SIZE 128 +#define GROUP_SIZE MAX_MESHLET_SIZE +#define ROOT_SIG "" + +struct VertexOut +{ + float4 PositionVS : SV_Position; +}; + +[RootSignature(ROOT_SIG)] +[NumThreads(GROUP_SIZE, 1, 1)] +[OutputTopology("triangle")] +void main_ms( + uint gtid : SV_GroupThreadID, + uint gid : SV_GroupID, + out indices uint3 tris[MAX_MESHLET_SIZE], + out vertices VertexOut verts[MAX_MESHLET_SIZE] +) +{ + verts[gtid].PositionVS = float4(0.0f, 0.0f, 0.0f, 1.0f); + + int vertex_count = 3; + int primitive_count = 1; + SetMeshOutputCounts(vertex_count, primitive_count); + if (gtid == 0) + { + tris[gtid] = uint3(0, 1, 2); + } + + if (gtid < vertex_count) + { + if (gtid == 0) + { + verts[gtid].PositionVS = float4(-1.0, -1.0, 0.0, 1.0f); + } + else if (gtid == 1) + { + verts[gtid].PositionVS = float4(0.0, 1.0, 0.0, 1.0f); + } + else if (gtid == 2) + { + verts[gtid].PositionVS = float4(1.0, -1.0, 0.0, 1.0f); + } + } +} + + +float4 main_ps(VertexOut input) : SV_TARGET +{ + return float4(0.1, 1.0, 0.1, 1.0); +} \ No newline at end of file diff --git a/rafx/examples/meshshader_triangle/shaders/shaders.metal b/rafx/examples/meshshader_triangle/shaders/shaders.metal new file mode 100644 index 000000000..0241526f4 --- /dev/null +++ b/rafx/examples/meshshader_triangle/shaders/shaders.metal @@ -0,0 +1,34 @@ +#include + +using namespace metal; + +struct VertexOut { + float4 position [[position]]; +}; + +using mesh_t = mesh; + +[[mesh]] void main_ms(mesh_t m, uint thread_index [[thread_position_in_threadgroup]]) { + VertexOut v; + + if (thread_index == 0) { + 
v.position = float4(0.0, 1.0, 0.0, 1.0); + m.set_vertex(0, v); + + m.set_index(0, 0); + m.set_index(1, 1); + m.set_index(2, 2); + + m.set_primitive_count(1); + } else if (thread_index == 1) { + v.position = float4(1.0, -1.0, 0.0, 1.0); + m.set_vertex(1, v); + } else if (thread_index == 2) { + v.position = float4(-1.0, -1.0, 0.0, 1.0); + m.set_vertex(2, v); + } +} + +fragment half4 main_ps() { + return half4(0.1, 1.0, 0.1, 1.0); +}