diff --git a/Cargo.toml b/Cargo.toml index b444e302..e2b3af6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,26 +11,27 @@ repository = "https://github.com/egraphs-good/egg" version = "0.9.5" [dependencies] -env_logger = {version = "0.9.0", default-features = false} +env_logger = { version = "0.9.0", default-features = false } fxhash = "0.2.1" hashbrown = "0.12.1" indexmap = "1.8.1" instant = "0.1.12" log = "0.4.17" -smallvec = {version = "1.8.0", features = ["union", "const_generics"]} -symbol_table = {version = "0.2.0", features = ["global"]} +smallvec = { version = "1.8.0", features = ["union", "const_generics"] } +symbol_table = { version = "0.2.0", features = ["global"] } symbolic_expressions = "5.0.3" thiserror = "1.0.31" # for the lp feature -coin_cbc = {version = "0.1.6", optional = true} +coin_cbc = { version = "0.1.6", optional = true } # for the serde-1 feature -serde = {version = "1.0.137", features = ["derive"], optional = true} -vectorize = {version = "0.2.0", optional = true} +serde = { version = "1.0.137", features = ["derive"], optional = true } +vectorize = { version = "0.2.0", optional = true } # for the reports feature -serde_json = {version = "1.0.81", optional = true} +serde_json = { version = "1.0.81", optional = true } +saturating = "0.1.0" [dev-dependencies] ordered-float = "3.0.0" diff --git a/src/explain.rs b/src/explain.rs index 79692326..41dd7c3f 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -3,6 +3,7 @@ use crate::{ util::pretty_print, Analysis, EClass, EGraph, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, UnionFind, Var, }; +use saturating::Saturating; use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; @@ -10,6 +11,8 @@ use std::rc::Rc; use symbolic_expressions::Sexp; +type ProofCost = Saturating; + const CONGRUENCE_LIMIT: usize = 2; const GREEDY_NUM_ITERS: usize = 2; @@ -62,14 +65,14 @@ pub struct Explain { // the explanation. // Invariant: The distance is always <= the unoptimized distance // That is, less than or equal to the result of `distance_between` - shortest_explanation_memo: HashMap<(Id, Id), (usize, Id)>, + shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, } #[derive(Default)] struct DistanceMemo { - parent_distance: Vec<(Id, usize)>, + parent_distance: Vec<(Id, ProofCost)>, common_ancestor: HashMap<(Id, Id), Id>, - tree_depth: HashMap, + tree_depth: HashMap, } /// Explanation trees are the compact representation showing @@ -233,12 +236,12 @@ impl Explanation { /// Get the size of this explanation tree in terms of the number of rewrites /// in the let-bound version of the tree. - pub fn get_tree_size(&self) -> usize { + pub fn get_tree_size(&self) -> ProofCost { let mut seen = Default::default(); let mut seen_adjacent = Default::default(); - let mut sum = 0; + let mut sum: ProofCost = Saturating(0); for e in self.explanation_trees.iter() { - sum += self.tree_size(&mut seen, &mut seen_adjacent, e); + sum = sum + self.tree_size(&mut seen, &mut seen_adjacent, e); } sum } @@ -248,21 +251,21 @@ impl Explanation { seen: &mut HashSet<*const TreeTerm>, seen_adjacent: &mut HashSet<(Id, Id)>, current: &Rc>, - ) -> usize { + ) -> ProofCost { if !seen.insert(&**current as *const TreeTerm) { - return 0; + return Saturating(0); } - let mut my_size = 0; + let mut my_size: ProofCost = Saturating(0); if current.forward_rule.is_some() { - my_size += 1; + my_size += Saturating(1); } if current.backward_rule.is_some() { - my_size += 1; + my_size += Saturating(1); } - assert!(my_size <= 1); - if my_size == 1 { + assert!(my_size <= Saturating(1)); + if my_size == Saturating(1) { if !seen_adjacent.insert((current.current, current.last)) { - return 0; + return Saturating(0); } else { seen_adjacent.insert((current.last, current.current)); } @@ -270,7 +273,7 @@ impl Explanation { for child_proof in ¤t.child_proofs { for child in child_proof { - my_size += self.tree_size(seen, seen_adjacent, child); + my_size = self.tree_size(seen, seen_adjacent, child); } } my_size @@ -853,7 +856,7 @@ impl FlatTerm { // Make sure to use push_increase instead of push when using priority queue #[derive(Copy, Clone, Eq, PartialEq)] struct HeapState { - cost: usize, + cost: ProofCost, item: I, } // The priority queue depends on `Ord`. @@ -1080,7 +1083,7 @@ impl Explain { return; } if let Some((cost, _)) = self.shortest_explanation_memo.get(&(node1, node2)) { - if cost <= &1 { + if cost <= &Saturating(1) { return; } } @@ -1106,9 +1109,9 @@ impl Explain { .neighbors .push(rconnection); self.shortest_explanation_memo - .insert((node1, node2), (1, node2)); + .insert((node1, node2), (Saturating(1), node2)); self.shortest_explanation_memo - .insert((node2, node1), (1, node1)); + .insert((node2, node1), (Saturating(1), node1)); } pub(crate) fn union( @@ -1132,9 +1135,9 @@ impl Explain { if let Justification::Rule(_) = justification { self.shortest_explanation_memo - .insert((node1, node2), (1, node2)); + .insert((node1, node2), (Saturating(1), node2)); self.shortest_explanation_memo - .insert((node2, node1), (1, node1)); + .insert((node2, node1), (Saturating(1), node1)); } let pconnection = Connection { @@ -1455,20 +1458,20 @@ impl Explain { enodes } - fn add_tree_depths(&self, node: Id, depths: &mut HashMap) -> usize { + fn add_tree_depths(&self, node: Id, depths: &mut HashMap) -> ProofCost { if depths.get(&node).is_none() { let parent = self.parent(node); let depth = if parent == node { - 0 + Saturating(0) } else { - self.add_tree_depths(parent, depths) + 1 + self.add_tree_depths(parent, depths) + Saturating(1) }; depths.insert(node, depth); } return *depths.get(&node).unwrap(); } - fn calculate_tree_depths(&self) -> HashMap { + fn calculate_tree_depths(&self) -> HashMap { let mut depths = HashMap::default(); for i in 0..self.explainfind.len() { self.add_tree_depths(Id::from(i), &mut depths); @@ -1476,7 +1479,7 @@ impl Explain { depths } - fn replace_distance(&mut self, current: Id, next: Id, right: Id, distance: usize) { + fn replace_distance(&mut self, current: Id, next: Id, right: Id, distance: ProofCost) { self.shortest_explanation_memo .insert((current, right), (distance, next)); } @@ -1486,11 +1489,10 @@ impl Explain { right: Id, left_connections: &[Connection], distance_memo: &mut DistanceMemo, - target_cost: usize, ) { self.shortest_explanation_memo - .insert((right, right), (0, right)); - let mut last_cost = 0; + .insert((right, right), (Saturating(0), right)); + let mut last_cost = Saturating(0); for connection in left_connections.iter().rev() { let next = connection.next; let current = connection.current; @@ -1503,12 +1505,16 @@ impl Explain { last_cost = dist + next_cost; self.replace_distance(current, next, right, next_cost + dist); } - assert!(last_cost <= target_cost); } - fn distance_between(&mut self, left: Id, right: Id, distance_memo: &mut DistanceMemo) -> usize { + fn distance_between( + &mut self, + left: Id, + right: Id, + distance_memo: &mut DistanceMemo, + ) -> ProofCost { if left == right { - return 0; + return Saturating(0); } let ancestor = if let Some(a) = distance_memo.common_ancestor.get(&(left, right)) { *a @@ -1535,11 +1541,13 @@ impl Explain { ); // calculate distance to find upper bound - match b.checked_add(c) { - Some(added) => added - .checked_sub(a.checked_mul(2).unwrap_or(0)) - .unwrap_or(usize::MAX), - None => usize::MAX, + match b.0.checked_add(c.0) { + Some(added) => Saturating( + added + .checked_sub(a.0.checked_mul(2).unwrap_or(0)) + .unwrap_or(usize::MAX), + ), + None => Saturating(usize::MAX), } //assert_eq!(dist+1, Explanation::new(self.explain_enodes(left, right, &mut Default::default())).make_flat_explanation().len()); @@ -1550,20 +1558,16 @@ impl Explain { current: Id, next: Id, distance_memo: &mut DistanceMemo, - ) -> usize { + ) -> ProofCost { let current_node = self.explainfind[usize::from(current)].node.clone(); let next_node = self.explainfind[usize::from(next)].node.clone(); - let mut cost: usize = 0; + let mut cost: ProofCost = Saturating(0); for (left_child, right_child) in current_node .children() .iter() .zip(next_node.children().iter()) { - cost = cost.saturating_add(self.distance_between( - *left_child, - *right_child, - distance_memo, - )); + cost += self.distance_between(*left_child, *right_child, distance_memo); } cost } @@ -1572,12 +1576,12 @@ impl Explain { &mut self, connection: &Connection, distance_memo: &mut DistanceMemo, - ) -> usize { + ) -> ProofCost { match connection.justification { Justification::Congruence => { self.congruence_distance(connection.current, connection.next, distance_memo) } - Justification::Rule(_) => 1, + Justification::Rule(_) => Saturating(1), } } @@ -1586,7 +1590,7 @@ impl Explain { enode: Id, ancestor: Id, distance_memo: &mut DistanceMemo, - ) -> usize { + ) -> ProofCost { loop { let parent = distance_memo.parent_distance[usize::from(enode)].0; let dist = distance_memo.parent_distance[usize::from(enode)].1; @@ -1596,8 +1600,7 @@ impl Explain { let parent_parent = distance_memo.parent_distance[usize::from(parent)].0; if parent_parent != parent { - let new_dist = - dist.saturating_add(distance_memo.parent_distance[usize::from(parent)].1); + let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1; distance_memo.parent_distance[usize::from(enode)] = (parent_parent, new_dist); } else { if ancestor == Id::from(usize::MAX) { @@ -1617,7 +1620,7 @@ impl Explain { Justification::Congruence => { self.congruence_distance(current, next, distance_memo) } - Justification::Rule(_) => 1, + Justification::Rule(_) => Saturating(1), }; distance_memo.parent_distance[usize::from(parent)] = (self.parent(parent), cost); } @@ -1703,7 +1706,7 @@ impl Explain { ) -> Option<(Vec, Vec)> { let mut todo = BinaryHeap::new(); todo.push(HeapState { - cost: 0, + cost: Saturating(0), item: Connection { current: start, next: start, @@ -1737,7 +1740,7 @@ impl Explain { for neighbor in &self.explainfind[usize::from(current)].neighbors { if let Justification::Rule(_) = neighbor.justification { - let neighbor_cost = cost_so_far.saturating_add(1); + let neighbor_cost = cost_so_far + Saturating(1); todo.push(HeapState { item: neighbor.clone(), cost: neighbor_cost, @@ -1748,7 +1751,7 @@ impl Explain { for other in congruence_neighbors[usize::from(current)].iter() { let next = other; let distance = self.congruence_distance(current, *next, distance_memo); - let next_cost = cost_so_far.saturating_add(distance); + let next_cost = cost_so_far + distance; todo.push(HeapState { item: Connection { current, @@ -1767,7 +1770,7 @@ impl Explain { let mut right_connections = vec![]; // we would like to assert that we found a path better than the normal one - // but since proof sizes are saturated (saturating_add) this is not true + // but since proof sizes are saturated this is not true /*let dist = self.distance_between(start, end, distance_memo); if *total_cost.unwrap() > dist { panic!( @@ -1776,7 +1779,7 @@ impl Explain { dist ); }*/ - if *total_cost.unwrap() == self.distance_between(start, end, distance_memo) { + if *total_cost.unwrap() >= self.distance_between(start, end, distance_memo) { let (a_left_connections, a_right_connections) = self.get_path_unoptimized(start, end); left_connections = a_left_connections; right_connections = a_right_connections; @@ -1793,12 +1796,7 @@ impl Explain { } } connections.reverse(); - self.populate_path_length( - end, - &connections, - distance_memo, - *path_cost.get(&end).unwrap(), - ); + self.populate_path_length(end, &connections, distance_memo); left_connections = connections; } @@ -1974,7 +1972,7 @@ impl Explain { ) { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; self.find_congruence_neighbors::(classes, &mut congruence_neighbors, unionfind); - let mut parent_distance = vec![(Id::from(0), 0); self.explainfind.len()]; + let mut parent_distance = vec![(Id::from(0), Saturating(0)); self.explainfind.len()]; for (i, entry) in parent_distance.iter_mut().enumerate() { entry.0 = Id::from(i); }