Skip to content

Commit

Permalink
Merge branch 'egraphs-good:main' into multipattern-explain
Browse files Browse the repository at this point in the history
  • Loading branch information
eytans authored Jun 21, 2024
2 parents 38ebcb6 + ae2db37 commit 4193d51
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 56 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## [Unreleased] - ReleaseDate
- Change the API of `make` to have mutable access to the e-graph for some [advanced use cases](https://github.com/egraphs-good/egg/pull/277).
- Fix an e-matching performance regression introduced in [this commit](https://github.com/egraphs-good/egg/commit/ae8af8815231e4aba1b78962f8c07ce837ee1c0e#diff-1d06da761111802c793c6e5ca704bfa0d6336d0becf87fddff02d81548a838ab).
- Use `quanta` instead of `instant` crate to provide timing information. This can provide a huge speedup if you have lots of rules, since it avoids some syscalls.

## [0.9.5] - 2023-06-29
- Fixed a few edge cases in proof size optimization that caused egg to crash.
Expand Down
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ env_logger = { version = "0.9.0", default-features = false }
fxhash = "0.2.1"
hashbrown = "0.12.1"
indexmap = "1.8.1"
instant = "0.1.12"
quanta = "0.12"
log = "0.4.17"
smallvec = { version = "1.8.0", features = ["union", "const_generics"] }
symbol_table = { version = "0.2.0", features = ["global"] }
symbolic_expressions = "5.0.3"
thiserror = "1.0.31"
num-bigint = "0.4"
num-traits = "0.2"

# for the lp feature
coin_cbc = { version = "0.1.6", optional = true }
Expand Down Expand Up @@ -48,7 +50,7 @@ serde-1 = [
"symbol_table/serde",
"vectorize",
]
wasm-bindgen = ["instant/wasm-bindgen"]
wasm-bindgen = []

# private features for testing
test-explanations = []
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.60
1.60
84 changes: 41 additions & 43 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language,
PatternAst, RecExpr, Rewrite, UnionFind, Var,
};
use saturating::Saturating;

use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::fmt::{self, Debug, Display, Formatter};
Expand All @@ -12,7 +12,10 @@ use std::rc::Rc;

use symbolic_expressions::Sexp;

type ProofCost = Saturating<usize>;
use num_bigint::BigUint;
use num_traits::identities::{One, Zero};

type ProofCost = BigUint;

const CONGRUENCE_LIMIT: usize = 2;
const GREEDY_NUM_ITERS: usize = 2;
Expand Down Expand Up @@ -252,7 +255,7 @@ impl<L: Language + Display + FromOp> Explanation<L> {
pub fn get_tree_size(&self) -> ProofCost {
let mut seen = Default::default();
let mut seen_adjacent = Default::default();
let mut sum: ProofCost = Saturating(0);
let mut sum: ProofCost = BigUint::zero();
for e in self.explanation_trees.iter() {
sum += self.tree_size(&mut seen, &mut seen_adjacent, e);
}
Expand All @@ -266,19 +269,19 @@ impl<L: Language + Display + FromOp> Explanation<L> {
current: &Rc<TreeTerm<L>>,
) -> ProofCost {
if !seen.insert(&**current as *const TreeTerm<L>) {
return Saturating(0);
return BigUint::zero();
}
let mut my_size: ProofCost = Saturating(0);
let mut my_size: ProofCost = BigUint::zero();
if current.forward_rule.is_some() {
my_size += Saturating(1);
my_size += 1_u32;
}
if current.backward_rule.is_some() {
my_size += Saturating(1);
my_size += 1_u32;
}
assert!(my_size <= Saturating(1));
if my_size == Saturating(1) {
assert!(my_size.is_zero() || my_size.is_one());
if my_size.is_one() {
if !seen_adjacent.insert((current.current, current.last)) {
return Saturating(0);
return BigUint::zero();
} else {
seen_adjacent.insert((current.last, current.current));
}
Expand Down Expand Up @@ -867,7 +870,7 @@ impl<L: Language> FlatTerm<L> {
}

// Make sure to use push_increase instead of push when using priority queue
#[derive(Copy, Clone, Eq, PartialEq)]
#[derive(Clone, Eq, PartialEq)]
struct HeapState<I> {
cost: ProofCost,
item: I,
Expand Down Expand Up @@ -954,7 +957,7 @@ impl<L: Language> Explain<L> {
return;
}
if let Some((cost, _)) = self.shortest_explanation_memo.get(&(node1, node2)) {
if cost <= &Saturating(1) {
if cost.is_zero() || cost.is_one() {
return;
}
}
Expand All @@ -980,9 +983,9 @@ impl<L: Language> Explain<L> {
.neighbors
.push(rconnection);
self.shortest_explanation_memo
.insert((node1, node2), (Saturating(1), node2));
.insert((node1, node2), (BigUint::one(), node2));
self.shortest_explanation_memo
.insert((node2, node1), (Saturating(1), node1));
.insert((node2, node1), (BigUint::one(), node1));
}

pub(crate) fn union(
Expand All @@ -1004,9 +1007,9 @@ impl<L: Language> Explain<L> {

if let Justification::Rule(_) = justification {
self.shortest_explanation_memo
.insert((node1, node2), (Saturating(1), node2));
.insert((node1, node2), (BigUint::one(), node2));
self.shortest_explanation_memo
.insert((node2, node1), (Saturating(1), node1));
.insert((node2, node1), (BigUint::one(), node1));
}

let pconnection = Connection {
Expand Down Expand Up @@ -1316,7 +1319,6 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
use_unoptimized: bool,
) -> TreeExplanation<L> {
let mut proof = vec![self.node_to_explanation(left, node_explanation_cache)];

let (left_connections, right_connections) = if use_unoptimized {
self.get_path_unoptimized(left, right)
} else {
Expand Down Expand Up @@ -1420,13 +1422,15 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
if depths.get(&node).is_none() {
let parent = self.parent(node);
let depth = if parent == node {
Saturating(0)
BigUint::zero()
} else {
self.add_tree_depths(parent, depths) + Saturating(1)
self.add_tree_depths(parent, depths) + 1_u32
};

depths.insert(node, depth);
}
return *depths.get(&node).unwrap();

depths.get(&node).unwrap().clone()
}

fn calculate_tree_depths(&self) -> HashMap<Id, ProofCost> {
Expand All @@ -1449,15 +1453,16 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
distance_memo: &mut DistanceMemo,
) {
self.shortest_explanation_memo
.insert((right, right), (Saturating(0), right));
.insert((right, right), (BigUint::zero(), right));
for connection in left_connections.iter().rev() {
let next = connection.next;
let current = connection.current;
let next_cost = self
.shortest_explanation_memo
.get(&(next, right))
.unwrap()
.0;
.0
.clone();
let dist = self.connection_distance(connection, distance_memo);
self.replace_distance(current, next, right, next_cost + dist);
}
Expand All @@ -1470,7 +1475,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
distance_memo: &mut DistanceMemo,
) -> ProofCost {
if left == right {
return Saturating(0);
return BigUint::zero();
}
let ancestor = if let Some(a) = distance_memo.common_ancestor.get(&(left, right)) {
*a
Expand All @@ -1497,14 +1502,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
);

// calculate distance to find upper bound
match b.0.checked_add(c.0) {
Some(added) => Saturating(
added
.checked_sub(a.0.checked_mul(2).unwrap_or(0))
.unwrap_or(usize::MAX),
),
None => Saturating(usize::MAX),
}
b + c - (a << 1)

//assert_eq!(dist+1, Explanation::new(self.explain_enodes(left, right, &mut Default::default())).make_flat_explanation().len());
}
Expand All @@ -1517,7 +1515,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
) -> ProofCost {
let current_node = self.node(current).clone();
let next_node = self.node(next).clone();
let mut cost: ProofCost = Saturating(0);
let mut cost: ProofCost = BigUint::zero();
for (left_child, right_child) in current_node
.children()
.iter()
Expand All @@ -1537,7 +1535,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
Justification::Congruence => {
self.congruence_distance(connection.current, connection.next, distance_memo)
}
Justification::Rule(_) => Saturating(1),
Justification::Rule(_) => BigUint::one(),
}
}

Expand All @@ -1549,14 +1547,14 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
) -> ProofCost {
loop {
let parent = distance_memo.parent_distance[usize::from(enode)].0;
let dist = distance_memo.parent_distance[usize::from(enode)].1;
let dist = distance_memo.parent_distance[usize::from(enode)].1.clone();
if self.parent(parent) == parent {
break;
}

let parent_parent = distance_memo.parent_distance[usize::from(parent)].0;
if parent_parent != parent {
let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1;
let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1.clone();
distance_memo.parent_distance[usize::from(enode)] = (parent_parent, new_dist);
} else {
if ancestor == Id::from(usize::MAX) {
Expand All @@ -1576,7 +1574,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
Justification::Congruence => {
self.congruence_distance(current, next, distance_memo)
}
Justification::Rule(_) => Saturating(1),
Justification::Rule(_) => BigUint::one(),
};
distance_memo.parent_distance[usize::from(parent)] = (self.parent(parent), cost);
}
Expand All @@ -1585,7 +1583,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
//assert_eq!(distance_memo.parent_distance[usize::from(enode)].1+1,
//Explanation::new(self.explain_enodes(enode, distance_memo.parent_distance[usize::from(enode)].0, &mut Default::default())).make_flat_explanation().len());

distance_memo.parent_distance[usize::from(enode)].1
distance_memo.parent_distance[usize::from(enode)].1.clone()
}

fn find_congruence_neighbors<N: Analysis<L>>(
Expand Down Expand Up @@ -1662,7 +1660,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
) -> Option<(Vec<Connection>, Vec<Connection>)> {
let mut todo = BinaryHeap::new();
todo.push(HeapState {
cost: Saturating(0),
cost: BigUint::zero(),
item: Connection {
current: start,
next: start,
Expand All @@ -1680,14 +1678,14 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
}
let state = todo.pop().unwrap();
let connection = state.item;
let cost_so_far = state.cost;
let cost_so_far = state.cost.clone();
let current = connection.next;

if last.get(&current).is_some() {
continue 'outer;
} else {
last.insert(current, connection);
path_cost.insert(current, cost_so_far);
path_cost.insert(current, cost_so_far.clone());
}

if current == end {
Expand All @@ -1696,7 +1694,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {

for neighbor in &self.explainfind[usize::from(current)].neighbors {
if let Justification::Rule(_) = neighbor.justification {
let neighbor_cost = cost_so_far + Saturating(1);
let neighbor_cost = cost_so_far.clone() + 1_u32;
todo.push(HeapState {
item: neighbor.clone(),
cost: neighbor_cost,
Expand All @@ -1707,7 +1705,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
for other in congruence_neighbors[usize::from(current)].iter() {
let next = other;
let distance = self.congruence_distance(current, *next, distance_memo);
let next_cost = cost_so_far + distance;
let next_cost = cost_so_far.clone() + distance;
todo.push(HeapState {
item: Connection {
current,
Expand Down Expand Up @@ -1928,7 +1926,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
) {
let mut congruence_neighbors = vec![vec![]; self.explainfind.len()];
self.find_congruence_neighbors::<N>(classes, &mut congruence_neighbors, unionfind);
let mut parent_distance = vec![(Id::from(0), Saturating(0)); self.explainfind.len()];
let mut parent_distance = vec![(Id::from(0), BigUint::zero()); self.explainfind.len()];
for (i, entry) in parent_distance.iter_mut().enumerate() {
entry.0 = Id::from(i);
}
Expand Down
7 changes: 4 additions & 3 deletions src/extract.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::cmp::Ordering;
use std::fmt::Debug;

use crate::util::HashMap;
use crate::util::{hashmap_with_capacity, HashMap};
use crate::{Analysis, EClass, EGraph, Id, Language, RecExpr};

/** Extracting a single [`RecExpr`] from an [`EGraph`].
Expand Down Expand Up @@ -134,8 +134,9 @@ pub trait CostFunction<L: Language> {
/// down the [`RecExpr`].
///
fn cost_rec(&mut self, expr: &RecExpr<L>) -> Self::Cost {
let mut costs: HashMap<Id, Self::Cost> = HashMap::default();
for (i, node) in expr.as_ref().iter().enumerate() {
let nodes = expr.as_ref();
let mut costs = hashmap_with_capacity::<Id, Self::Cost>(nodes.len());
for (i, node) in nodes.iter().enumerate() {
let cost = self.cost(node, |i| costs[&i].clone());
costs.insert(Id::from(i), cost);
}
Expand Down
12 changes: 11 additions & 1 deletion src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ impl<L: Language> RecExpr<L> {
}

pub(crate) fn compact(mut self) -> Self {
let mut ids = HashMap::<Id, Id>::default();
let mut ids = hashmap_with_capacity::<Id, Id>(self.nodes.len());
let mut set = IndexSet::default();
for (i, node) in self.nodes.drain(..).enumerate() {
let node = node.map_children(|id| ids[&id]);
Expand Down Expand Up @@ -765,6 +765,16 @@ pub trait Analysis<L: Language>: Sized {
/// `Analysis::merge` when unions are performed.
#[allow(unused_variables)]
fn modify(egraph: &mut EGraph<L, Self>, id: Id) {}

/// Whether or not e-matching should allow finding cycles.
///
/// By default, this returns `true`.
///
/// Setting this to `false` can improve performance in some cases, but risks
/// missing some equalities depending on the use case.
///
/// NOTE(review): the e-matching machine appears to consult this flag for each
/// candidate match and, when it is `false`, to skip any match whose first
/// register also occurs among the remaining registers. Confirm against the
/// `Program` implementation in `machine.rs`.
fn allow_ematching_cycles(&self) -> bool {
true
}
}

impl<L: Language> Analysis<L> for () {
Expand Down
8 changes: 8 additions & 0 deletions src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,14 @@ impl<L: Language> Program<L> {
&self.instructions,
&self.subst,
&mut |machine, subst| {
if !egraph.analysis.allow_ematching_cycles() {
if let Some((first, rest)) = machine.reg.split_first() {
if rest.contains(first) {
return Ok(());
}
}
}

let subst_vec = subst
.vec
.iter()
Expand Down
Loading

0 comments on commit 4193d51

Please sign in to comment.