diff --git a/engine/src/gfx.rs b/engine/src/gfx.rs index 638bda6b..660f42d9 100644 --- a/engine/src/gfx.rs +++ b/engine/src/gfx.rs @@ -57,6 +57,7 @@ pub struct GfxContext { pub sun_shadowmap: Texture, pub pbr: PBR, pub lamplights: LampLights, + pub defines: FastMap, pub simplelit_bg: wgpu::BindGroup, pub bnoise_bg: wgpu::BindGroup, @@ -270,6 +271,7 @@ impl GfxContext { device, queue, pbr, + defines: Default::default(), }; me.update_simplelit_bg(); @@ -304,6 +306,17 @@ impl GfxContext { self.materials.get_mut(id).zip(Some(&self.queue)) } + pub fn set_define_flag(&mut self, name: &str, inserted: bool) { + if self.defines.contains_key(name) == inserted { + return; + } + if inserted { + self.defines.insert(name.to_string(), String::new()); + } else { + self.defines.remove(name); + } + } + pub fn mk_shadowmap(device: &Device, res: u32) -> Texture { let format = TextureFormat::Depth32Float; let extent = wgpu::Extent3d { @@ -452,6 +465,7 @@ impl GfxContext { &mut p.shader_watcher, &self.device, name, + &self.defines, ) } @@ -648,7 +662,7 @@ impl GfxContext { self.pipelines .try_borrow_mut() .unwrap() - .check_shader_updates(&self.device); + .check_shader_updates(&self.defines, &self.device); } self.tick += 1; } diff --git a/engine/src/pbr.rs b/engine/src/pbr.rs index c8b5f58b..95298983 100644 --- a/engine/src/pbr.rs +++ b/engine/src/pbr.rs @@ -2,6 +2,7 @@ use crate::{ compile_shader, CompiledModule, GfxContext, PipelineBuilder, Texture, TextureBuilder, Uniform, TL, }; +use common::FastMap; use geom::{Vec3, Vec4}; use wgpu::{ BindGroup, BlendState, CommandEncoder, CommandEncoderDescriptor, Device, FragmentState, LoadOp, @@ -255,7 +256,8 @@ impl PBR { push_constant_ranges: &[], }); - let brdf_convolution_module = compile_shader(device, "pbr/brdf_convolution"); + let brdf_convolution_module = + compile_shader(device, "pbr/brdf_convolution", &FastMap::default()); let cubemapline = device.create_render_pipeline(&RenderPipelineDescriptor { label: None, diff --git a/engine/src/pipelines.rs b/engine/src/pipelines.rs index 2be0532f..cf5c2e63 100644 --- a/engine/src/pipelines.rs +++ b/engine/src/pipelines.rs @@ -29,7 +29,10 @@ pub struct Pipelines { impl Pipelines { pub fn new(device: &Device) -> Pipelines { let mut shader_cache = FastMap::default(); - shader_cache.insert("mipmap".to_string(), compile_shader(device, "mipmap")); + shader_cache.insert( + "mipmap".to_string(), + compile_shader(device, "mipmap", &FastMap::default()), + ); Pipelines { shader_cache, @@ -44,6 +47,7 @@ impl Pipelines { shader_watcher: &mut FastMap, Option)>, device: &Device, name: &str, + defines: &FastMap, ) -> CompiledModule { if let Some(v) = shader_cache.get(name) { return v.clone(); @@ -51,7 +55,7 @@ impl Pipelines { shader_cache .entry(name.to_string()) .or_insert_with_key(move |key| { - let module = compile_shader(device, key); + let module = compile_shader(device, key, defines); for dep in module.get_deps() { shader_watcher @@ -89,6 +93,7 @@ impl Pipelines { &mut self.shader_watcher, device, name, + &gfx.defines, ) }); for dep in deps { @@ -100,10 +105,15 @@ impl Pipelines { } } - pub fn invalidate(&mut self, device: &Device, shader_name: &str) { + pub fn invalidate( + &mut self, + defines: &FastMap, + device: &Device, + shader_name: &str, + ) { if let Some(x) = self.shader_cache.get_mut(shader_name) { device.push_error_scope(ErrorFilter::Validation); - let new_shader = compile_shader(device, shader_name); + let new_shader = compile_shader(device, shader_name, defines); let scope = beul::execute(device.pop_error_scope()); if scope.is_some() { log::error!("failed to compile shader for invalidation {}", shader_name); @@ -123,7 +133,7 @@ impl Pipelines { } } - pub fn check_shader_updates(&mut self, device: &Device) { + pub fn check_shader_updates(&mut self, defines: &FastMap, device: &Device) { let mut to_invalidate = HashSet::new(); for (sname, (parents, entry)) in &mut self.shader_watcher { let meta = unwrap_cont!(std::fs::metadata(Path::new(&format!( @@ -146,7 +156,7 @@ impl Pipelines { } for sname in to_invalidate { log::info!("invalidating shader {}", sname); - self.invalidate(device, &sname); + self.invalidate(defines, device, &sname); } } } diff --git a/engine/src/shader.rs b/engine/src/shader.rs index 6fb7d18c..07377cff 100644 --- a/engine/src/shader.rs +++ b/engine/src/shader.rs @@ -1,4 +1,5 @@ use crate::wgpu::ShaderSource; +use common::FastMap; use std::borrow::Cow; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -33,7 +34,11 @@ fn mk_module(data: String, device: &Device) -> ShaderModule { } /// if type isn't provided it will be detected by looking at extension -pub fn compile_shader(device: &Device, name: &str) -> CompiledModule { +pub fn compile_shader( + device: &Device, + name: &str, + defines: &FastMap, +) -> CompiledModule { let t = Instant::now(); defer!(log::info!( "compiling shader {} took {:?}", @@ -55,14 +60,70 @@ pub fn compile_shader(device: &Device, name: &str) -> CompiledModule { .unwrap(); let mut deps = vec![]; - source = replace_imports(&p, source, &mut deps); + source = replace_imports(&p, &source, &mut deps); + source = apply_ifdefs(defines, &source); let wgsl = mk_module(source, device); CompiledModule(Rc::new((wgsl, deps))) } -fn replace_imports(base: &Path, src: String, deps: &mut Vec) -> String { +/// apply_ifdefs updates the source taking into account #ifdef and #ifndef +/// syntax is as follow: +/// #ifdef or #ifndef +/// +/// #else or #elif +/// +/// #endif +/// +/// the ifdefs can be nested +fn apply_ifdefs(defines: &FastMap, src: &str) -> String { + // A stack of: + // whether that nest level is true + // whether we've seen a true yet in the if/elif chain + let mut ifdef_stack: Vec<(bool, bool)> = vec![]; + + src.lines() + .map(|line| { + let x = line.trim(); + if let Some(mut ifdef) = x.strip_prefix("#ifdef ") { + ifdef = ifdef.trim(); + let should_execute = defines.contains_key(ifdef); + ifdef_stack.push((should_execute, should_execute)); + return ""; + } + if let Some(mut ifndef) = x.strip_prefix("#ifndef ") { + ifndef = ifndef.trim(); + let should_execute = !defines.contains_key(ifndef); + ifdef_stack.push((should_execute, should_execute)); + return ""; + } + if let Some(_) = x.strip_prefix("#else") { + let (val, has_true) = ifdef_stack.last_mut().unwrap(); + *val = !*val && !*has_true; + return ""; + } + if let Some(mut elif) = x.strip_prefix("#elifdef ") { + elif = elif.trim(); + let (val, has_true) = ifdef_stack.last_mut().unwrap(); + *val = !*val && defines.contains_key(elif); + *has_true = *has_true || *val; + return ""; + } + if let Some(_) = x.strip_prefix("#endif") { + ifdef_stack.pop(); + return ""; + } + if ifdef_stack.iter().any(|(val, _)| !*val) { + return ""; + } + line + }) + .collect::>() + .join("\n") +} + +fn replace_imports(base: &Path, src: &str, deps: &mut Vec) -> String { src.lines() .map(move |x| { if let Some(mut loc) = x.strip_prefix("#include \"") { @@ -73,7 +134,7 @@ fn replace_imports(base: &Path, src: String, deps: &mut Vec) -> String { p.push(loc); let mut s = std::fs::read_to_string(p) .unwrap_or_else(|_| panic!("could not find included file {loc}")); - s = replace_imports(base, s, deps); + s = replace_imports(base, &s, deps); return Cow::Owned(s); } Cow::Borrowed(x) @@ -81,3 +142,60 @@ fn replace_imports(base: &Path, src: String, deps: &mut Vec) -> String { .collect::>() .join("\n") } + +#[cfg(test)] +mod tests { + use common::FastMap; + + #[test] + fn test_apply_ifdefs() { + let src = r#" + #ifdef A + a + #else + b + #endif + "#; + + fn f(x: &[&str]) -> FastMap { + x.iter().map(|x| (x.to_string(), "".to_string())).collect() + } + + assert_eq!(super::apply_ifdefs(&f(&[]), &src).trim(), "b"); + assert_eq!(super::apply_ifdefs(&f(&["A"]), &src).trim(), "a"); + assert_eq!(super::apply_ifdefs(&f(&["B"]), &src).trim(), "b"); + assert_eq!(super::apply_ifdefs(&f(&["A", "B"]), &src).trim(), "a"); + + let src = r#" + #ifdef A + a + #elifdef B + b + #else + c + #endif + "#; + + assert_eq!(super::apply_ifdefs(&f(&[]), &src).trim(), "c"); + assert_eq!(super::apply_ifdefs(&f(&["A"]), &src).trim(), "a"); + assert_eq!(super::apply_ifdefs(&f(&["B"]), &src).trim(), "b"); + assert_eq!(super::apply_ifdefs(&f(&["A", "B"]), &src).trim(), "a"); + + let src = r#" + #ifdef A + #ifdef B + a + #else + b + #endif + #else + c + #endif + "#; + + assert_eq!(super::apply_ifdefs(&f(&[]), &src).trim(), "c"); + assert_eq!(super::apply_ifdefs(&f(&["A"]), &src).trim(), "b"); + assert_eq!(super::apply_ifdefs(&f(&["B"]), &src).trim(), "c"); + assert_eq!(super::apply_ifdefs(&f(&["A", "B"]), &src).trim(), "a"); + } +}