Skip to content

Commit

Permalink
Fix proofs: switch to BigUint instead of Saturating<usize> for pr…
Browse files Browse the repository at this point in the history
…oof cost (#310)

* switch to biguint instead of saturating usize for proof cost

* nits
  • Loading branch information
bksaiki authored Apr 24, 2024
1 parent 2f1514c commit 556a6b3
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 48 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ smallvec = { version = "1.8.0", features = ["union", "const_generics"] }
symbol_table = { version = "0.2.0", features = ["global"] }
symbolic_expressions = "5.0.3"
thiserror = "1.0.31"
num-bigint = "0.4"
num-traits = "0.2"

# for the lp feature
coin_cbc = { version = "0.1.6", optional = true }
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.60
1.60
84 changes: 41 additions & 43 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language,
PatternAst, RecExpr, Rewrite, UnionFind, Var,
};
use saturating::Saturating;

use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::fmt::{self, Debug, Display, Formatter};
Expand All @@ -12,7 +12,10 @@ use std::rc::Rc;

use symbolic_expressions::Sexp;

type ProofCost = Saturating<usize>;
use num_bigint::BigUint;
use num_traits::identities::{One, Zero};

type ProofCost = BigUint;

const CONGRUENCE_LIMIT: usize = 2;
const GREEDY_NUM_ITERS: usize = 2;
Expand Down Expand Up @@ -252,7 +255,7 @@ impl<L: Language + Display + FromOp> Explanation<L> {
pub fn get_tree_size(&self) -> ProofCost {
let mut seen = Default::default();
let mut seen_adjacent = Default::default();
let mut sum: ProofCost = Saturating(0);
let mut sum: ProofCost = BigUint::zero();
for e in self.explanation_trees.iter() {
sum += self.tree_size(&mut seen, &mut seen_adjacent, e);
}
Expand All @@ -266,19 +269,19 @@ impl<L: Language + Display + FromOp> Explanation<L> {
current: &Rc<TreeTerm<L>>,
) -> ProofCost {
if !seen.insert(&**current as *const TreeTerm<L>) {
return Saturating(0);
return BigUint::zero();
}
let mut my_size: ProofCost = Saturating(0);
let mut my_size: ProofCost = BigUint::zero();
if current.forward_rule.is_some() {
my_size += Saturating(1);
my_size += 1_u32;
}
if current.backward_rule.is_some() {
my_size += Saturating(1);
my_size += 1_u32;
}
assert!(my_size <= Saturating(1));
if my_size == Saturating(1) {
assert!(my_size.is_zero() || my_size.is_one());
if my_size.is_one() {
if !seen_adjacent.insert((current.current, current.last)) {
return Saturating(0);
return BigUint::zero();
} else {
seen_adjacent.insert((current.last, current.current));
}
Expand Down Expand Up @@ -867,7 +870,7 @@ impl<L: Language> FlatTerm<L> {
}

// Make sure to use push_increase instead of push when using priority queue
#[derive(Copy, Clone, Eq, PartialEq)]
#[derive(Clone, Eq, PartialEq)]
struct HeapState<I> {
cost: ProofCost,
item: I,
Expand Down Expand Up @@ -954,7 +957,7 @@ impl<L: Language> Explain<L> {
return;
}
if let Some((cost, _)) = self.shortest_explanation_memo.get(&(node1, node2)) {
if cost <= &Saturating(1) {
if cost.is_zero() || cost.is_one() {
return;
}
}
Expand All @@ -980,9 +983,9 @@ impl<L: Language> Explain<L> {
.neighbors
.push(rconnection);
self.shortest_explanation_memo
.insert((node1, node2), (Saturating(1), node2));
.insert((node1, node2), (BigUint::one(), node2));
self.shortest_explanation_memo
.insert((node2, node1), (Saturating(1), node1));
.insert((node2, node1), (BigUint::one(), node1));
}

pub(crate) fn union(
Expand All @@ -1004,9 +1007,9 @@ impl<L: Language> Explain<L> {

if let Justification::Rule(_) = justification {
self.shortest_explanation_memo
.insert((node1, node2), (Saturating(1), node2));
.insert((node1, node2), (BigUint::one(), node2));
self.shortest_explanation_memo
.insert((node2, node1), (Saturating(1), node1));
.insert((node2, node1), (BigUint::one(), node1));
}

let pconnection = Connection {
Expand Down Expand Up @@ -1316,7 +1319,6 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
use_unoptimized: bool,
) -> TreeExplanation<L> {
let mut proof = vec![self.node_to_explanation(left, node_explanation_cache)];

let (left_connections, right_connections) = if use_unoptimized {
self.get_path_unoptimized(left, right)
} else {
Expand Down Expand Up @@ -1420,13 +1422,15 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
if depths.get(&node).is_none() {
let parent = self.parent(node);
let depth = if parent == node {
Saturating(0)
BigUint::zero()
} else {
self.add_tree_depths(parent, depths) + Saturating(1)
self.add_tree_depths(parent, depths) + 1_u32
};

depths.insert(node, depth);
}
return *depths.get(&node).unwrap();

depths.get(&node).unwrap().clone()
}

fn calculate_tree_depths(&self) -> HashMap<Id, ProofCost> {
Expand All @@ -1449,15 +1453,16 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
distance_memo: &mut DistanceMemo,
) {
self.shortest_explanation_memo
.insert((right, right), (Saturating(0), right));
.insert((right, right), (BigUint::zero(), right));
for connection in left_connections.iter().rev() {
let next = connection.next;
let current = connection.current;
let next_cost = self
.shortest_explanation_memo
.get(&(next, right))
.unwrap()
.0;
.0
.clone();
let dist = self.connection_distance(connection, distance_memo);
self.replace_distance(current, next, right, next_cost + dist);
}
Expand All @@ -1470,7 +1475,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
distance_memo: &mut DistanceMemo,
) -> ProofCost {
if left == right {
return Saturating(0);
return BigUint::zero();
}
let ancestor = if let Some(a) = distance_memo.common_ancestor.get(&(left, right)) {
*a
Expand All @@ -1497,14 +1502,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
);

// calculate distance to find upper bound
match b.0.checked_add(c.0) {
Some(added) => Saturating(
added
.checked_sub(a.0.checked_mul(2).unwrap_or(0))
.unwrap_or(usize::MAX),
),
None => Saturating(usize::MAX),
}
b + c - (a << 1)

//assert_eq!(dist+1, Explanation::new(self.explain_enodes(left, right, &mut Default::default())).make_flat_explanation().len());
}
Expand All @@ -1517,7 +1515,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
) -> ProofCost {
let current_node = self.node(current).clone();
let next_node = self.node(next).clone();
let mut cost: ProofCost = Saturating(0);
let mut cost: ProofCost = BigUint::zero();
for (left_child, right_child) in current_node
.children()
.iter()
Expand All @@ -1537,7 +1535,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
Justification::Congruence => {
self.congruence_distance(connection.current, connection.next, distance_memo)
}
Justification::Rule(_) => Saturating(1),
Justification::Rule(_) => BigUint::one(),
}
}

Expand All @@ -1549,14 +1547,14 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
) -> ProofCost {
loop {
let parent = distance_memo.parent_distance[usize::from(enode)].0;
let dist = distance_memo.parent_distance[usize::from(enode)].1;
let dist = distance_memo.parent_distance[usize::from(enode)].1.clone();
if self.parent(parent) == parent {
break;
}

let parent_parent = distance_memo.parent_distance[usize::from(parent)].0;
if parent_parent != parent {
let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1;
let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1.clone();
distance_memo.parent_distance[usize::from(enode)] = (parent_parent, new_dist);
} else {
if ancestor == Id::from(usize::MAX) {
Expand All @@ -1576,7 +1574,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
Justification::Congruence => {
self.congruence_distance(current, next, distance_memo)
}
Justification::Rule(_) => Saturating(1),
Justification::Rule(_) => BigUint::one(),
};
distance_memo.parent_distance[usize::from(parent)] = (self.parent(parent), cost);
}
Expand All @@ -1585,7 +1583,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
//assert_eq!(distance_memo.parent_distance[usize::from(enode)].1+1,
//Explanation::new(self.explain_enodes(enode, distance_memo.parent_distance[usize::from(enode)].0, &mut Default::default())).make_flat_explanation().len());

distance_memo.parent_distance[usize::from(enode)].1
distance_memo.parent_distance[usize::from(enode)].1.clone()
}

fn find_congruence_neighbors<N: Analysis<L>>(
Expand Down Expand Up @@ -1662,7 +1660,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
) -> Option<(Vec<Connection>, Vec<Connection>)> {
let mut todo = BinaryHeap::new();
todo.push(HeapState {
cost: Saturating(0),
cost: BigUint::zero(),
item: Connection {
current: start,
next: start,
Expand All @@ -1680,14 +1678,14 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
}
let state = todo.pop().unwrap();
let connection = state.item;
let cost_so_far = state.cost;
let cost_so_far = state.cost.clone();
let current = connection.next;

if last.get(&current).is_some() {
continue 'outer;
} else {
last.insert(current, connection);
path_cost.insert(current, cost_so_far);
path_cost.insert(current, cost_so_far.clone());
}

if current == end {
Expand All @@ -1696,7 +1694,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {

for neighbor in &self.explainfind[usize::from(current)].neighbors {
if let Justification::Rule(_) = neighbor.justification {
let neighbor_cost = cost_so_far + Saturating(1);
let neighbor_cost = cost_so_far.clone() + 1_u32;
todo.push(HeapState {
item: neighbor.clone(),
cost: neighbor_cost,
Expand All @@ -1707,7 +1705,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
for other in congruence_neighbors[usize::from(current)].iter() {
let next = other;
let distance = self.congruence_distance(current, *next, distance_memo);
let next_cost = cost_so_far + distance;
let next_cost = cost_so_far.clone() + distance;
todo.push(HeapState {
item: Connection {
current,
Expand Down Expand Up @@ -1928,7 +1926,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
) {
let mut congruence_neighbors = vec![vec![]; self.explainfind.len()];
self.find_congruence_neighbors::<N>(classes, &mut congruence_neighbors, unionfind);
let mut parent_distance = vec![(Id::from(0), Saturating(0)); self.explainfind.len()];
let mut parent_distance = vec![(Id::from(0), BigUint::zero()); self.explainfind.len()];
for (i, entry) in parent_distance.iter_mut().enumerate() {
entry.0 = Id::from(i);
}
Expand Down
7 changes: 3 additions & 4 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
These are not considered part of the public api.
*/

use num_traits::identities::Zero;
use std::{fmt::Display, fs::File, io::Write, path::PathBuf};

use saturating::Saturating;

use crate::*;

pub fn env_var<T>(s: &str) -> Option<T>
Expand Down Expand Up @@ -104,14 +103,14 @@ pub fn test_runner<L, A>(
let flattened = explained.make_flat_explanation().clone();
let vanilla_len = flattened.len();
explained.check_proof(rules);
assert!(explained.get_tree_size() > Saturating(0));
assert!(!explained.get_tree_size().is_zero());

runner = runner.with_explanation_length_optimization();
let mut explained_short = runner.explain_matches(&start, &goal.ast, &subst);
explained_short.get_string_with_let();
let short_len = explained_short.get_flat_strings().len();
assert!(short_len <= vanilla_len);
assert!(explained_short.get_tree_size() > Saturating(0));
assert!(!explained_short.get_tree_size().is_zero());
explained_short.check_proof(rules);
}
}
Expand Down

0 comments on commit 556a6b3

Please sign in to comment.