diff --git a/Cargo.toml b/Cargo.toml index e2b3af6b..427c5de4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,8 @@ smallvec = { version = "1.8.0", features = ["union", "const_generics"] } symbol_table = { version = "0.2.0", features = ["global"] } symbolic_expressions = "5.0.3" thiserror = "1.0.31" +num-bigint = "0.4" +num-traits = "0.2" # for the lp feature coin_cbc = { version = "0.1.6", optional = true } diff --git a/rust-toolchain b/rust-toolchain index 2fef84a8..64d00e7d 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.60 \ No newline at end of file +1.60 diff --git a/src/explain.rs b/src/explain.rs index 9de2a17e..c42eb635 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -3,7 +3,7 @@ use crate::{ util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, PatternAst, RecExpr, Rewrite, UnionFind, Var, }; -use saturating::Saturating; + use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; @@ -12,7 +12,10 @@ use std::rc::Rc; use symbolic_expressions::Sexp; -type ProofCost = Saturating; +use num_bigint::BigUint; +use num_traits::identities::{One, Zero}; + +type ProofCost = BigUint; const CONGRUENCE_LIMIT: usize = 2; const GREEDY_NUM_ITERS: usize = 2; @@ -252,7 +255,7 @@ impl Explanation { pub fn get_tree_size(&self) -> ProofCost { let mut seen = Default::default(); let mut seen_adjacent = Default::default(); - let mut sum: ProofCost = Saturating(0); + let mut sum: ProofCost = BigUint::zero(); for e in self.explanation_trees.iter() { sum += self.tree_size(&mut seen, &mut seen_adjacent, e); } @@ -266,19 +269,19 @@ impl Explanation { current: &Rc>, ) -> ProofCost { if !seen.insert(&**current as *const TreeTerm) { - return Saturating(0); + return BigUint::zero(); } - let mut my_size: ProofCost = Saturating(0); + let mut my_size: ProofCost = BigUint::zero(); if current.forward_rule.is_some() { - my_size += Saturating(1); + my_size += 1_u32; } if current.backward_rule.is_some() { - my_size += Saturating(1); + my_size += 1_u32; } - assert!(my_size <= Saturating(1)); - if my_size == Saturating(1) { + assert!(my_size.is_zero() || my_size.is_one()); + if my_size.is_one() { if !seen_adjacent.insert((current.current, current.last)) { - return Saturating(0); + return BigUint::zero(); } else { seen_adjacent.insert((current.last, current.current)); } @@ -867,7 +870,7 @@ impl FlatTerm { } // Make sure to use push_increase instead of push when using priority queue -#[derive(Copy, Clone, Eq, PartialEq)] +#[derive(Clone, Eq, PartialEq)] struct HeapState { cost: ProofCost, item: I, @@ -954,7 +957,7 @@ impl Explain { return; } if let Some((cost, _)) = self.shortest_explanation_memo.get(&(node1, node2)) { - if cost <= &Saturating(1) { + if cost.is_zero() || cost.is_one() { return; } } @@ -980,9 +983,9 @@ impl Explain { .neighbors .push(rconnection); self.shortest_explanation_memo - .insert((node1, node2), (Saturating(1), node2)); + .insert((node1, node2), (BigUint::one(), node2)); self.shortest_explanation_memo - .insert((node2, node1), (Saturating(1), node1)); + .insert((node2, node1), (BigUint::one(), node1)); } pub(crate) fn union( @@ -1004,9 +1007,9 @@ impl Explain { if let Justification::Rule(_) = justification { self.shortest_explanation_memo - .insert((node1, node2), (Saturating(1), node2)); + .insert((node1, node2), (BigUint::one(), node2)); self.shortest_explanation_memo - .insert((node2, node1), (Saturating(1), node1)); + .insert((node2, node1), (BigUint::one(), node1)); } let pconnection = Connection { @@ -1316,7 +1319,6 @@ impl<'x, L: Language> ExplainNodes<'x, L> { use_unoptimized: bool, ) -> TreeExplanation { let mut proof = vec![self.node_to_explanation(left, node_explanation_cache)]; - let (left_connections, right_connections) = if use_unoptimized { self.get_path_unoptimized(left, right) } else { @@ -1420,13 +1422,15 @@ impl<'x, L: Language> ExplainNodes<'x, L> { if depths.get(&node).is_none() { let parent = self.parent(node); let depth = if parent == node { - Saturating(0) + BigUint::zero() } else { - self.add_tree_depths(parent, depths) + Saturating(1) + self.add_tree_depths(parent, depths) + 1_u32 }; + depths.insert(node, depth); } - return *depths.get(&node).unwrap(); + + depths.get(&node).unwrap().clone() } fn calculate_tree_depths(&self) -> HashMap { @@ -1449,7 +1453,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { distance_memo: &mut DistanceMemo, ) { self.shortest_explanation_memo - .insert((right, right), (Saturating(0), right)); + .insert((right, right), (BigUint::zero(), right)); for connection in left_connections.iter().rev() { let next = connection.next; let current = connection.current; @@ -1457,7 +1461,8 @@ impl<'x, L: Language> ExplainNodes<'x, L> { .shortest_explanation_memo .get(&(next, right)) .unwrap() - .0; + .0 + .clone(); let dist = self.connection_distance(connection, distance_memo); self.replace_distance(current, next, right, next_cost + dist); } @@ -1470,7 +1475,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { distance_memo: &mut DistanceMemo, ) -> ProofCost { if left == right { - return Saturating(0); + return BigUint::zero(); } let ancestor = if let Some(a) = distance_memo.common_ancestor.get(&(left, right)) { *a @@ -1497,14 +1502,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { ); // calculate distance to find upper bound - match b.0.checked_add(c.0) { - Some(added) => Saturating( - added - .checked_sub(a.0.checked_mul(2).unwrap_or(0)) - .unwrap_or(usize::MAX), - ), - None => Saturating(usize::MAX), - } + b + c - (a << 1) //assert_eq!(dist+1, Explanation::new(self.explain_enodes(left, right, &mut Default::default())).make_flat_explanation().len()); } @@ -1517,7 +1515,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { ) -> ProofCost { let current_node = self.node(current).clone(); let next_node = self.node(next).clone(); - let mut cost: ProofCost = Saturating(0); + let mut cost: ProofCost = BigUint::zero(); for (left_child, right_child) in current_node .children() .iter() @@ -1537,7 +1535,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { Justification::Congruence => { self.congruence_distance(connection.current, connection.next, distance_memo) } - Justification::Rule(_) => Saturating(1), + Justification::Rule(_) => BigUint::one(), } } @@ -1549,14 +1547,14 @@ impl<'x, L: Language> ExplainNodes<'x, L> { ) -> ProofCost { loop { let parent = distance_memo.parent_distance[usize::from(enode)].0; - let dist = distance_memo.parent_distance[usize::from(enode)].1; + let dist = distance_memo.parent_distance[usize::from(enode)].1.clone(); if self.parent(parent) == parent { break; } let parent_parent = distance_memo.parent_distance[usize::from(parent)].0; if parent_parent != parent { - let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1; + let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1.clone(); distance_memo.parent_distance[usize::from(enode)] = (parent_parent, new_dist); } else { if ancestor == Id::from(usize::MAX) { @@ -1576,7 +1574,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { Justification::Congruence => { self.congruence_distance(current, next, distance_memo) } - Justification::Rule(_) => Saturating(1), + Justification::Rule(_) => BigUint::one(), }; distance_memo.parent_distance[usize::from(parent)] = (self.parent(parent), cost); } @@ -1585,7 +1583,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { //assert_eq!(distance_memo.parent_distance[usize::from(enode)].1+1, //Explanation::new(self.explain_enodes(enode, distance_memo.parent_distance[usize::from(enode)].0, &mut Default::default())).make_flat_explanation().len()); - distance_memo.parent_distance[usize::from(enode)].1 + distance_memo.parent_distance[usize::from(enode)].1.clone() } fn find_congruence_neighbors>( @@ -1662,7 +1660,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { ) -> Option<(Vec, Vec)> { let mut todo = BinaryHeap::new(); todo.push(HeapState { - cost: Saturating(0), + cost: BigUint::zero(), item: Connection { current: start, next: start, @@ -1680,14 +1678,14 @@ impl<'x, L: Language> ExplainNodes<'x, L> { } let state = todo.pop().unwrap(); let connection = state.item; - let cost_so_far = state.cost; + let cost_so_far = state.cost.clone(); let current = connection.next; if last.get(¤t).is_some() { continue 'outer; } else { last.insert(current, connection); - path_cost.insert(current, cost_so_far); + path_cost.insert(current, cost_so_far.clone()); } if current == end { @@ -1696,7 +1694,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { for neighbor in &self.explainfind[usize::from(current)].neighbors { if let Justification::Rule(_) = neighbor.justification { - let neighbor_cost = cost_so_far + Saturating(1); + let neighbor_cost = cost_so_far.clone() + 1_u32; todo.push(HeapState { item: neighbor.clone(), cost: neighbor_cost, @@ -1707,7 +1705,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { for other in congruence_neighbors[usize::from(current)].iter() { let next = other; let distance = self.congruence_distance(current, *next, distance_memo); - let next_cost = cost_so_far + distance; + let next_cost = cost_so_far.clone() + distance; todo.push(HeapState { item: Connection { current, @@ -1928,7 +1926,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { ) { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; self.find_congruence_neighbors::(classes, &mut congruence_neighbors, unionfind); - let mut parent_distance = vec![(Id::from(0), Saturating(0)); self.explainfind.len()]; + let mut parent_distance = vec![(Id::from(0), BigUint::zero()); self.explainfind.len()]; for (i, entry) in parent_distance.iter_mut().enumerate() { entry.0 = Id::from(i); } diff --git a/src/test.rs b/src/test.rs index 10815d66..4784e816 100644 --- a/src/test.rs +++ b/src/test.rs @@ -3,10 +3,9 @@ These are not considered part of the public api. */ +use num_traits::identities::Zero; use std::{fmt::Display, fs::File, io::Write, path::PathBuf}; -use saturating::Saturating; - use crate::*; pub fn env_var(s: &str) -> Option @@ -104,14 +103,14 @@ pub fn test_runner( let flattened = explained.make_flat_explanation().clone(); let vanilla_len = flattened.len(); explained.check_proof(rules); - assert!(explained.get_tree_size() > Saturating(0)); + assert!(!explained.get_tree_size().is_zero()); runner = runner.with_explanation_length_optimization(); let mut explained_short = runner.explain_matches(&start, &goal.ast, &subst); explained_short.get_string_with_let(); let short_len = explained_short.get_flat_strings().len(); assert!(short_len <= vanilla_len); - assert!(explained_short.get_tree_size() > Saturating(0)); + assert!(!explained_short.get_tree_size().is_zero()); explained_short.check_proof(rules); } }