From 329de885a7cc1669cc18cbda9a9847355821aade Mon Sep 17 00:00:00 2001 From: Caleb Date: Wed, 17 Jul 2024 09:48:09 -0400 Subject: [PATCH] check if this works --- calyx-opt/src/analysis/static_schedule.rs | 29 ++++- calyx-opt/src/passes/compile_static.rs | 148 ++++++++++++---------- 2 files changed, 100 insertions(+), 77 deletions(-) diff --git a/calyx-opt/src/analysis/static_schedule.rs b/calyx-opt/src/analysis/static_schedule.rs index fa4b3c7149..64c135bd70 100644 --- a/calyx-opt/src/analysis/static_schedule.rs +++ b/calyx-opt/src/analysis/static_schedule.rs @@ -435,8 +435,7 @@ impl Node { } /// Get max value of all nodes in the tree, according to some function f. - /// `f` takes in a Tree (i.e., a node type) and returns a `u64`. Note that this - /// function assumes a minimum value of 1. ** This is a weird assumption**. + /// `f` takes in a Tree (i.e., a node type) and returns a `u64`. pub fn get_max_value(&self, name: &ir::Id, f: &F) -> u64 where F: Fn(&SingleNode) -> u64, @@ -1039,6 +1038,8 @@ impl SingleNode { ); let fsm_identifier = match self.fsm_cell.as_ref() { + // If the tree does not have an fsm cell, then we can err on the + // side of giving it its own unique identifier. None => self.root.0, Some(fsm_rc) => fsm_rc.borrow().fsm_cell.borrow().name(), }; @@ -1595,8 +1596,10 @@ impl SingleNode { } // Adds conflicts between children and any descendents. - // And add conflicts between any overlapping children (XXX(Caleb): need to - // do this for dumb rzn.) + // Also add conflicts between any overlapping children. 
XXX(Caleb): normally + // there shouldn't be overlapping children, but when we are doing the traditional + // method in which we don't offload (and therefore don't need this tree structure), + // I have created dummy trees for the sole purpose of drawing conflicts pub fn add_conflicts(&self, conflict_graph: &mut GraphColoring) { let root_name = self.root.0; for (child, _) in &self.children { @@ -1605,13 +1608,18 @@ impl SingleNode { } child.add_conflicts(conflict_graph); } + // Adding conflicts between overlapping children. for ((child_a, (beg_a, end_a)), (child_b, (beg_b, end_b))) in self.children.iter().tuple_combinations() { + // Checking if children overlap: either b begins within a, it + // ends within a, or it encompasses a's entire interval. if ((beg_a <= beg_b) & (beg_b < end_b)) | ((beg_a < end_b) & (end_b <= end_b)) | (beg_b <= beg_a && end_b >= end_a) { + // Adding conflicts between all nodes of the children if + // the children overlap. for a_node in child_a.get_all_nodes() { for b_node in child_b.get_all_nodes() { conflict_graph.insert_conflict(&a_node, &b_node); @@ -1626,7 +1634,7 @@ impl SingleNode { where F: Fn(&SingleNode) -> u64, { - let mut cur_max = 1; + let mut cur_max = 0; if self.root.0 == name { cur_max = std::cmp::max(cur_max, f(self)); } @@ -1762,11 +1770,18 @@ impl ParNodes { }), ); + let fsm_identifier = match longest_node.fsm_cell.as_ref() { + // If the tree does not have an fsm cell, then we can err on the + // side of giving it its own unique identifier. 
+ None => longest_node.root.0, + Some(fsm_rc) => fsm_rc.borrow().fsm_cell.borrow().name(), + }; + let total_latency = self.latency * self.num_repeats; fsm_info_map.insert( early_reset_group.borrow().name(), ( - self.group_name, + fsm_identifier, self.query_between((0, 1), builder), self.query_between((total_latency - 1, total_latency), builder), ), @@ -1915,7 +1930,7 @@ impl ParNodes { where F: Fn(&SingleNode) -> u64, { - let mut cur_max = 1; + let mut cur_max = 0; for (thread, _) in &self.threads { cur_max = std::cmp::max(cur_max, thread.get_max_value(name, f)); } diff --git a/calyx-opt/src/passes/compile_static.rs b/calyx-opt/src/passes/compile_static.rs index 6bed0f1d8f..1bdb3618fa 100644 --- a/calyx-opt/src/passes/compile_static.rs +++ b/calyx-opt/src/passes/compile_static.rs @@ -23,13 +23,20 @@ pub struct CompileStatic { wrapper_map: HashMap, /// maps fsm names to their corresponding signal_reg signal_reg_map: HashMap, - /// maps reset_early_group names to (fsm == 0, final_fsm_state); + /// maps reset_early_group names to (fsm_identifier, fsm_first_state, final_fsm_state); + /// The "fsm identifier" is just the name of the fsm (if it exists) and + /// some other unique identifier if it doesn't exist (this works because + /// it is always fine to give each entry its own completely unique identifier.) fsm_info_map: HashMap, ir::Guard)>, /// Command line arguments: - /// cutoff for one hot encoding + /// Cutoff for one hot encoding. Anything larger than the cutoff becomes + /// binary. one_hot_cutoff: u64, + /// Bool indicating whether to make the FSM pause (i.e., stop counting) when + /// offloading. In order for compilation to make sense, this parameter must + /// match the parameter for `static-inline`. 
offload_pause: bool, } @@ -84,14 +91,14 @@ impl ConstructVisitor for CompileStatic { } impl CompileStatic { - /// Builds a wrapper group for group named group_name using fsm and + /// Builds a wrapper group for group named group_name using fsm_final_state /// and a signal_reg. - /// Both the group and FSM (and the signal_reg) should already exist. + /// We set the signal_reg high on the final fsm state, since we know the + /// `done` signal should be high the next cycle after that. /// `add_resetting_logic` is a bool; since the same FSM/signal_reg pairing /// may be used for multiple static islands, and we only add resetting logic /// for the signal_reg once. fn build_wrapper_group( - fsm_eq_0: ir::Guard, fsm_final_state: ir::Guard, group_name: &ir::Id, signal_reg: ir::RRC, @@ -110,7 +117,6 @@ impl CompileStatic { ) }); - // fsm.out == 0 structure!( builder; let signal_on = constant(1, 1); let signal_off = constant(0, 1); @@ -121,10 +127,9 @@ impl CompileStatic { guard!(signal_reg["out"]); // !signal_reg.out let not_signal_reg = signal_reg_guard.clone().not(); - // fsm.out == 0 & signal_reg.out - let eq_0_and_signal = fsm_eq_0.clone() & signal_reg_guard; - // fsm.out == 0 & ! signal_reg.out + // fsm_final_state & !signal_reg.out let final_state_not_signal = fsm_final_state & not_signal_reg; + // create the wrapper group for early_reset_group let mut wrapper_name = group_name.clone().to_string(); wrapper_name.insert_str(0, "wrapper_"); @@ -133,11 +138,11 @@ impl CompileStatic { builder; // early_reset_group[go] = 1'd1 early_reset_group["go"] = ? signal_on["out"]; - // when fsm == 0, and !signal_reg, then set signal_reg to high + // when fsm_final_state and !signal_reg, then set signal_reg to high signal_reg["write_en"] = final_state_not_signal ? signal_on["out"]; signal_reg["in"] = final_state_not_signal ? signal_on["out"]; - // group[done] = fsm.out == 0 & signal_reg.out ? 1'd1 - g["done"] = eq_0_and_signal ? signal_on["out"]; + // group[done] = signal_reg.out ? 
1'd1 + g["done"] = signal_reg_guard ? signal_on["out"]; ); if add_reseting_logic { // continuous assignments to reset signal_reg back to 0 when the wrapper is done @@ -145,8 +150,8 @@ impl CompileStatic { builder; // when (fsm == 0 & signal_reg is high), which is the done condition of the wrapper, // reset the signal_reg back to low - signal_reg["write_en"] = eq_0_and_signal ? signal_on["out"]; - signal_reg["in"] = eq_0_and_signal ? signal_off["out"]; + signal_reg["write_en"] = signal_reg_guard ? signal_on["out"]; + signal_reg["in"] = signal_reg_guard ? signal_off["out"]; ); builder.add_continuous_assignments(continuous_assigns.to_vec()); } @@ -228,7 +233,8 @@ impl CompileStatic { } } -// These are the functions used to allocate FSMs to static islands +// These are the functions used to allocate FSMs to static islands through a +// greedy coloring algorithm. impl CompileStatic { // Given a list of `static_groups`, find the group named `name`. // If there is no such group, then there is an unreachable! error. @@ -246,39 +252,9 @@ impl CompileStatic { ) } - // Gets all of the triggered static groups within `c`, and adds it to `cur_names`. - // Relies on sgroup_uses_map to take into account groups that are triggered through - // their `go` hole. - fn get_used_sgroups(c: &ir::Control, cur_names: &mut HashSet) { - match c { - ir::Control::Empty(_) - | ir::Control::Enable(_) - | ir::Control::Invoke(_) => (), - ir::Control::Static(sc) => { - let ir::StaticControl::Enable(s) = sc else { - unreachable!("Non-Enable Static Control should have been compiled away. Run {} to do this", crate::passes::StaticInliner::name()); - }; - let group_name = s.group.borrow().name(); - cur_names.insert(group_name); - } - ir::Control::Par(ir::Par { stmts, .. }) - | ir::Control::Seq(ir::Seq { stmts, .. }) => { - for stmt in stmts { - Self::get_used_sgroups(stmt, cur_names); - } - } - ir::Control::Repeat(ir::Repeat { body, .. }) - | ir::Control::While(ir::While { body, .. 
}) => { - Self::get_used_sgroups(body, cur_names); - } - ir::Control::If(if_stmt) => { - Self::get_used_sgroups(&if_stmt.tbranch, cur_names); - Self::get_used_sgroups(&if_stmt.fbranch, cur_names); - } - } - } - - /// XXX(Caleb): Todo. + /// Add conflicts between all nodes of `fsm_trees` which are executing + /// on separate threads of a dynamic `par` block. + /// This function adds conflicts between nodes of separate trees. fn add_par_conflicts( c: &ir::Control, fsm_trees: &Vec, @@ -311,32 +287,38 @@ impl CompileStatic { ); } ir::Control::Par(par) => { - // sgroup_conflict_vec is a vec of HashSets. - // Each entry of the vec corresponds to a par thread, and holds + // `sgroup_conflict_vec` is a vec of HashSets. + // Each item in the vec corresponds to a par thread, and holds // all of the groups executed in that thread. - let mut sgroup_conflict_vec = Vec::new(); + let mut sgroup_conflict_vec: Vec> = Vec::new(); for stmt in &par.stmts { - let mut used_sgroups = HashSet::new(); - Self::get_used_sgroups(stmt, &mut used_sgroups); - sgroup_conflict_vec.push(used_sgroups); + sgroup_conflict_vec.push(HashSet::from_iter( + Self::get_static_enables(stmt), + )); } - for (thread1_sgroups, thread2_sgroups) in + for (thread1_st_enables, thread2_st_enables) in sgroup_conflict_vec.iter().tuple_combinations() { - for static_enable1 in thread1_sgroups { - for static_enable2 in thread2_sgroups { + // For each static group g1 enabled in thread1 and static + // group g2 enabled in thread2 respectively, add a conflict + // between each node in g1 and g2's corresponding trees. 
+ for static_enable1 in thread1_st_enables { + for static_enable2 in thread2_st_enables { + // Getting tree1 let tree1 = fsm_trees .iter() .find(|tree| { tree.get_group_name() == static_enable1 }) .expect("couldn't find FSM tree"); + // Getting tree2 let tree2 = fsm_trees .iter() .find(|tree| { tree.get_group_name() == static_enable2 }) .expect("couldn't find tree"); + // Add conflict between each node in tree1 and tree2 for sgroup1 in tree1.get_all_nodes() { for sgroup2 in tree2.get_all_nodes() { conflict_graph @@ -354,6 +336,10 @@ impl CompileStatic { } } + // Gets the maximum number of repeats for the static group named + // `sgroup` among all trees in `tree_objects`. Most of the time, `sgroup` + // will only appear once but it is possible that the same group appears + // in more than one tree. fn get_max_num_repeats(sgroup: ir::Id, tree_objects: &Vec) -> u64 { let mut cur_max = 1; for tree in tree_objects { @@ -364,6 +350,11 @@ impl CompileStatic { } cur_max } + + // Gets the maximum number of states for the static group named + // `sgroup` among all trees in `tree_objects`. Most of the time, `sgroup` + // will only appear once but it is possible that the same group appears + // in more than one tree. fn get_max_num_states(sgroup: ir::Id, tree_objects: &Vec) -> u64 { let mut cur_max = 1; for tree in tree_objects { @@ -375,6 +366,11 @@ impl CompileStatic { cur_max } + /// Creates a graph (one node per item in `sgroups` where nodes are the `sgroups`' + /// names). + /// Uses `tree_objects` and `control` to draw conflicts between any two nodes + /// that could be executing in parallel, and returns a greedy coloring of the + /// graph. 
pub fn get_coloring( tree_objects: &Vec, sgroups: &[ir::RRC], @@ -383,11 +379,17 @@ impl CompileStatic { let mut conflict_graph: GraphColoring = GraphColoring::from(sgroups.iter().map(|g| g.borrow().name())); // Necessary conflicts to ensure correctness + + // Self::add_par_conflicts adds necessary conflicts between all nodes of + // trees that execute in separate threads of the same `par` block: this is + // adding conflicts between nodes of separate trees. Self::add_par_conflicts(control, tree_objects, &mut conflict_graph); for tree in tree_objects { + // tree.add_conflicts adds the necessary conflicts between nodes of + // the same tree. tree.add_conflicts(&mut conflict_graph); } - // Optional conflicts to improve QoR + // Optional conflicts to ?potentially? improve QoR // for (sgroup1, sgroup2) in sgroups.iter().tuple_combinations() { // let max_num_states1 = // Self::get_max_num_states(sgroup1.borrow().name(), tree_objects); @@ -414,12 +416,16 @@ impl CompileStatic { conflict_graph.color_greedy(None, true) } + /// Given a coloring of group names, returns a HashMap that maps: + /// colors -> (max num states for that color, max num repeats for that color). pub fn get_color_max_values( coloring: &HashMap, tree_objects: &Vec, ) -> HashMap { let mut colors_to_sgroups: HashMap> = HashMap::new(); + // "Reverse" the coloring: instead of mapping group names -> colors, + // map colors -> group names. 
for (group_name, color) in coloring { colors_to_sgroups .entry(*color) @@ -429,11 +435,13 @@ impl CompileStatic { colors_to_sgroups .into_iter() .map(|(name, colors_sgroups)| { + // Get max num states for this color let max_num_states = colors_sgroups .iter() .map(|gname| Self::get_max_num_states(*gname, tree_objects)) .max() .expect("color is empty"); + // Get max num repeats for this color let max_num_repeats = colors_sgroups .iter() .map(|gname| { @@ -448,10 +456,14 @@ impl CompileStatic { } impl CompileStatic { + /// `get_interval_from_guard` returns the interval found within guard `g`. + /// The tricky part is that sometimes there can be an implicit latency + /// `lat` that is not explicitly stated (i.e., as an %[i:j]) in the guard. + /// XXX(Caleb): why do we need to handle `or`'s? Like why would they even + /// appear? fn get_interval_from_guard( g: &ir::Guard, lat: u64, - id: &ir::Id, ) -> Option<(u64, u64)> { match g { calyx_ir::Guard::Info(static_timing_interval) => { @@ -463,8 +475,8 @@ impl CompileStatic { | calyx_ir::Guard::True => Some((0, lat)), calyx_ir::Guard::And(l, r) => { match ( - Self::get_interval_from_guard(l, lat, id), - Self::get_interval_from_guard(r, lat, id), + Self::get_interval_from_guard(l, lat), + Self::get_interval_from_guard(r, lat), ) { (None, Some(x)) | (Some(x), None) => Some(x), (None, None) => { @@ -480,7 +492,7 @@ impl CompileStatic { } } } - ir::Guard::Or(_, _) => None, + ir::Guard::Or(_, _) => panic!(""), } } @@ -510,7 +522,6 @@ impl CompileStatic { let x = Self::get_interval_from_guard( &assign.guard, target_group.borrow().get_latency(), - &name, ); let (beg, end) = x.expect("couldn't get interval from guard"); @@ -620,7 +631,6 @@ impl CompileStatic { let x = Self::get_interval_from_guard( &assign.guard, target_group.borrow().get_latency(), - &name, ); let (beg, end) = x.expect("couldn't get interval from guard"); @@ -727,7 +737,7 @@ impl CompileStatic { } ir::Control::Static(sc) => { let ir::StaticControl::Enable(s) = sc 
else { - panic!("") + unreachable!("Non-Enable Static Control should have been compiled away. Run {} to do this", crate::passes::StaticInliner::name()); }; vec![s.group.borrow().name()] } @@ -1139,7 +1149,7 @@ impl Visitor for CompileStatic { None => { // create the builder/cells that we need to create wrapper group let mut builder = ir::Builder::new(comp, sigs); - let (fsm_name, fsm_eq_0, fsm_final_state) = self.fsm_info_map.get(early_reset_name).unwrap_or_else(|| unreachable!("group {} has no correspondoing fsm in self.fsm_map", early_reset_name)); + let (fsm_name, _, fsm_final_state) = self.fsm_info_map.get(early_reset_name).unwrap_or_else(|| unreachable!("group {} has no correspondoing fsm in self.fsm_map", early_reset_name)); // If we've already made a wrapper for a group that uses the same // FSM, we can reuse the signal_reg. Otherwise, we must // instantiate a new signal_reg. @@ -1153,7 +1163,6 @@ impl Visitor for CompileStatic { self.signal_reg_map .insert(*fsm_name, signal_reg.borrow().name()); Self::build_wrapper_group( - fsm_eq_0.clone(), fsm_final_state.clone(), early_reset_name, signal_reg, @@ -1172,7 +1181,6 @@ impl Visitor for CompileStatic { unreachable!("signal reg {reg_name} found") }); Self::build_wrapper_group( - fsm_eq_0.clone(), fsm_final_state.clone(), early_reset_name, signal_reg, @@ -1245,9 +1253,9 @@ impl Visitor for CompileStatic { let reset_group_name = self.get_reset_group_name(sc); // Get fsm for reset_group - let (_, fsm_eq_0, _) = self.fsm_info_map.get(reset_group_name).unwrap_or_else(|| unreachable!("group {} has no correspondoing fsm in self.fsm_map", reset_group_name)); + let (_, fsm_first_state, _) = self.fsm_info_map.get(reset_group_name).unwrap_or_else(|| unreachable!("group {} has no correspondoing fsm in self.fsm_map", reset_group_name)); let wrapper_group = self.build_wrapper_group_while( - fsm_eq_0.clone(), + fsm_first_state.clone(), reset_group_name, Rc::clone(&s.port), &mut builder,