diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index c8ec854504..9da4873957 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -35,8 +35,12 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { // Drop all the functions we'll be inlining. (This also means we won't waste time processing // inlines in functions that will get inlined) let mut dropped_ids = FxHashSet::default(); + let mut inlined_dont_inlines = Vec::new(); module.functions.retain(|f| { if should_inline(&disallowed_argument_types, &disallowed_return_types, f) { + if has_dont_inline(f) { + inlined_dont_inlines.push(f.def_id().unwrap()); + } // TODO: We should insert all defined IDs in this function. dropped_ids.insert(f.def_id().unwrap()); false @@ -44,6 +48,16 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { true } }); + if !inlined_dont_inlines.is_empty() { + let names = get_names(module); + for f in inlined_dont_inlines { + sess.warn(&format!( + "function `{}` has `dont_inline` attribute, but need to be inlined because it has illegal argument or return types", + get_name(&names, f) + )); + } + } + // Drop OpName etc. for inlined functions module.debug_names.retain(|inst| { !inst.operands.iter().any(|op| { @@ -204,6 +218,12 @@ fn compute_disallowed_argument_and_return_types( (disallowed_argument_types, disallowed_return_types) } +fn has_dont_inline(function: &Function) -> bool { + let def = function.def.as_ref().unwrap(); + let control = def.operands[0].unwrap_function_control(); + control.contains(FunctionControl::DONT_INLINE) +} + fn should_inline( disallowed_argument_types: &FxHashSet, disallowed_return_types: &FxHashSet, diff --git a/crates/rustc_codegen_spirv/src/linker/inline_globals.rs b/crates/rustc_codegen_spirv/src/linker/inline_globals.rs new file mode 100644 index 0000000000..ab712ba11e --- /dev/null +++ b/crates/rustc_codegen_spirv/src/linker/inline_globals.rs @@ -0,0 +1,370 @@ +use rspirv::dr::{Instruction, Module, Operand}; +use rspirv::spirv::Op; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_session::Session; + +#[derive(Debug, Clone, PartialEq)] +struct NormalizedInstructions { + vars: Vec, + insts: Vec, + root: u32, +} + +impl NormalizedInstructions { + fn new(id: u32) -> Self { + NormalizedInstructions { + vars: Vec::new(), + insts: Vec::new(), + root: id, + } + } + + fn extend(&mut self, o: NormalizedInstructions) { + self.vars.extend(o.vars); + self.insts.extend(o.insts); + } + + fn is_empty(&self) -> bool { + self.insts.is_empty() && self.vars.is_empty() + } + + fn fix_ids(&mut self, bound: &mut u32, new_root: u32) { + let mut id_map: FxHashMap = FxHashMap::default(); + id_map.insert(self.root, new_root); + for inst in &mut self.vars { + Self::fix_instruction(self.root, inst, &mut id_map, bound, new_root); + } + for inst in &mut self.insts { + Self::fix_instruction(self.root, inst, &mut id_map, bound, new_root); + } + self.root = new_root; + } + + fn fix_instruction( + root: u32, + inst: &mut Instruction, + id_map: &mut FxHashMap, + bound: &mut u32, + new_root: u32, + ) { + for op in &mut inst.operands { + match op { + Operand::IdRef(id) => match id_map.get(id) { + Some(new_id) => { + *id = *new_id; + } + _ => {} + }, + _ => {} + } + } + if let Some(id) = &mut inst.result_id { + if *id != root { + id_map.insert(*id, *bound); + *id = *bound; + *bound += 1; + } else { + *id = new_root; + } + } + } +} + +#[derive(Debug, Clone, PartialEq)] +enum FunctionArg { + Invalid, + Insts(NormalizedInstructions), +} + +pub fn inline_global_varaibles(sess: &Session, module: &mut Module) -> super::Result<()> { + let mut cont = true; + let mut has_run = false; + //let mut i = 0; + //std::fs::write("res0.txt", module.disassemble()); + while cont { + cont = inline_global_varaibles_rec(module)?; + has_run = has_run || cont; + // i += 1; + //std::fs::write(format!("res{}.txt", i), module.disassemble()); + } + // needed because inline global create duplicate types... + if has_run { + let _timer = sess.timer("link_remove_duplicate_types_round_2"); + super::duplicates::remove_duplicate_types(module); + } + Ok(()) +} + +fn inline_global_varaibles_rec(module: &mut Module) -> super::Result { + // first collect global stuff + let mut variables: FxHashSet = FxHashSet::default(); + let mut function_types: FxHashMap = FxHashMap::default(); + for global_inst in &module.types_global_values { + let opcode = global_inst.class.opcode; + if opcode == Op::Variable || opcode == Op::Constant { + variables.insert(global_inst.result_id.unwrap()); + } else if opcode == Op::TypeFunction { + function_types.insert(global_inst.result_id.unwrap(), global_inst.clone()); + } + } + // then we keep track of which function parameter are always called with the same expression that only uses global variables + let mut function_args: FxHashMap<(u32, u32), FunctionArg> = FxHashMap::default(); + let mut bound = module.header.as_ref().unwrap().bound; + for caller in &module.functions { + let mut insts: FxHashMap = FxHashMap::default(); + // for variables that only stored once and it's stored as a ref + let mut ref_stores: FxHashMap> = FxHashMap::default(); + for block in &caller.blocks { + for inst in &block.instructions { + if inst.result_id.is_some() { + insts.insert(inst.result_id.unwrap(), inst.clone()); + } + if inst.class.opcode == Op::Store { + if let Operand::IdRef(to) = inst.operands[0] { + if let Operand::IdRef(from) = inst.operands[1] { + match ref_stores.get(&to) { + None => { + ref_stores.insert(to, Some(from)); + } + Some(_) => { + ref_stores.insert(to, None); + } + } + } + } + } else if inst.class.opcode == Op::FunctionCall { + let function_id = match &inst.operands[0] { + &Operand::IdRef(w) => w, + _ => panic!(), + }; + for i in 1..inst.operands.len() { + let key = (function_id, i as u32 - 1); + // default to invalid to avoid duplicated code + let mut is_invalid = true; + match &inst.operands[i] { + &Operand::IdRef(w) => match &function_args.get(&key) { + None => { + match get_const_arg_insts( + bound, + &variables, + &insts, + &ref_stores, + w, + ) { + Some(insts) => { + is_invalid = false; + function_args.insert(key, FunctionArg::Insts(insts)); + } + None => {} + } + } + Some(FunctionArg::Insts(w2)) => { + let new_insts = get_const_arg_insts( + bound, + &variables, + &insts, + &ref_stores, + w, + ); + match new_insts { + Some(new_insts) => { + is_invalid = new_insts != *w2; + } + None => {} + } + } + _ => {} + }, + _ => {} + }; + if is_invalid { + function_args.insert(key, FunctionArg::Invalid); + } + } + } + } + } + } + // retain ones can rewrite + function_args.retain(|_, k| match k { + FunctionArg::Invalid => false, + FunctionArg::Insts(v) => !v.is_empty(), + }); + if function_args.is_empty() { + return Ok(false); + } + // start rewrite + for function in &mut module.functions { + let def = function.def.as_mut().unwrap(); + let fid = def.result_id.unwrap(); + let mut insts = NormalizedInstructions::new(0); + let mut j: u32 = 0; + let mut i = 0; + let mut removed_indexes: Vec = Vec::new(); + // callee side. remove parameters from function def + while i < function.parameters.len() { + let mut removed = false; + match &function_args.get(&(fid, j)) { + Some(FunctionArg::Insts(arg)) => { + let parameter = function.parameters.remove(i); + let mut arg = arg.clone(); + arg.fix_ids(&mut bound, parameter.result_id.unwrap()); + insts.extend(arg); + removed_indexes.push(j); + removed = true; + } + _ => (), + } + if !removed { + i += 1; + } + j += 1; + } + // callee side. and add a new function type in global section + if removed_indexes.len() > 0 { + if let Operand::IdRef(tid) = def.operands[1] { + let mut function_type: Instruction = function_types.get(&tid).unwrap().clone(); + let tid: u32 = bound; + bound += 1; + for i in removed_indexes.iter().rev() { + let i = *i as usize + 1; + function_type.operands.remove(i); + } + function_type.result_id = Some(tid); + def.operands[1] = Operand::IdRef(tid); + module.types_global_values.push(function_type); + } + } + // callee side. insert initialization instructions, which reuse the ids of the removed parameters + if !function.blocks.is_empty() { + let first_block = &mut function.blocks[0]; + first_block.instructions.splice(0..0, insts.vars); + // skip some instructions that must be at top of block + let mut i = 0; + loop { + if i >= first_block.instructions.len() { + break; + } + let inst = &first_block.instructions[i]; + if inst.class.opcode == Op::Label || inst.class.opcode == Op::Variable { + } else { + break; + } + i += 1; + } + first_block.instructions.splice(i..i, insts.insts); + } + // caller side, remove parameters from function call + for block in &mut function.blocks { + for inst in &mut block.instructions { + if inst.class.opcode == Op::FunctionCall { + let function_id = match &inst.operands[0] { + &Operand::IdRef(w) => w, + _ => panic!(), + }; + let mut removed_size = 0; + for i in 0..inst.operands.len() - 1 { + if function_args.contains_key(&(function_id, i as u32)) { + inst.operands.remove(i - removed_size + 1); + removed_size += 1; + } + } + } + } + } + } + if let Some(header) = &mut module.header { + header.bound = bound; + } + Ok(true) +} + +fn get_const_arg_operands( + variables: &FxHashSet, + insts: &FxHashMap, + ref_stores: &FxHashMap>, + operand: &Operand, +) -> Option { + match operand { + Operand::IdRef(id) => { + let insts = get_const_arg_insts_rec(variables, insts, ref_stores, *id)?; + return Some(insts); + } + Operand::LiteralInt32(_) => {} + Operand::LiteralInt64(_) => {} + Operand::LiteralFloat32(_) => {} + Operand::LiteralFloat64(_) => {} + Operand::LiteralExtInstInteger(_) => {} + Operand::LiteralSpecConstantOpInteger(_) => {} + Operand::LiteralString(_) => {} + _ => { + // TOOD add more cases + return None; + } + } + return Some(NormalizedInstructions::new(0)); +} + +fn get_const_arg_insts( + mut bound: u32, + variables: &FxHashSet, + insts: &FxHashMap, + ref_stores: &FxHashMap>, + id: u32, +) -> Option { + let mut res = get_const_arg_insts_rec(variables, insts, ref_stores, id)?; + res.insts.reverse(); + // the bound passed in is always the same + // we need to normalize the ids, so they are the same when compared + let fake_root = bound; + bound += 1; + res.fix_ids(&mut bound, fake_root); + Some(res) +} + +fn get_const_arg_insts_rec( + variables: &FxHashSet, + insts: &FxHashMap, + ref_stores: &FxHashMap>, + id: u32, +) -> Option { + let mut result = NormalizedInstructions::new(id); + if variables.contains(&id) { + return Some(result); + } + let par: &Instruction = insts.get(&id)?; + if par.class.opcode == Op::AccessChain { + result.insts.push(par.clone()); + for oprand in &par.operands { + let insts = get_const_arg_operands(variables, insts, ref_stores, oprand)?; + result.extend(insts); + } + } else if par.class.opcode == Op::FunctionCall { + result.insts.push(par.clone()); + // skip first, first is function id + for oprand in &par.operands[1..] { + let insts = get_const_arg_operands(variables, insts, ref_stores, oprand)?; + result.extend(insts); + } + } else if par.class.opcode == Op::Variable { + result.vars.push(par.clone()); + let stored = ref_stores.get(&id)?; + let stored = (*stored)?; + result.insts.push(Instruction::new( + Op::Store, + None, + None, + vec![Operand::IdRef(id), Operand::IdRef(stored)], + )); + let new_insts = get_const_arg_insts_rec(variables, insts, ref_stores, stored)?; + result.extend(new_insts); + } else if par.class.opcode == Op::ArrayLength { + result.insts.push(par.clone()); + let insts = get_const_arg_operands(variables, insts, ref_stores, &par.operands[0])?; + result.extend(insts); + } else { + // TOOD add more cases + return None; + } + Some(result) +} diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 5df298dd5d..d9f896ac0b 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -7,6 +7,7 @@ mod duplicates; mod entry_interface; mod import_export_link; mod inline; +mod inline_globals; mod ipo; mod mem2reg; mod param_weakening; @@ -216,6 +217,11 @@ pub fn link(sess: &Session, mut inputs: Vec, opts: &Options) -> Result f32{ + (con[1] - con[0]) as f32 +} + +#[spirv(fragment)] +pub fn main( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] runtime_array: &mut [u32], +) { + sdf(runtime_array); +}