Skip to content

Commit

Permalink
Remove existence explanations (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt authored Dec 31, 2024
1 parent a18666d commit 4747cfe
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 226 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changes

## [Unreleased] - ReleaseDate
- Removed existence explanations from egg (the `explain_existance` function). This feature was buggy and not well supported. Supporting it fully required many changes, and it is incompatible with analysis. See #332 for more details.
- Change the API of `make` to have mutable access to the e-graph for some [advanced uses cases](https://github.com/egraphs-good/egg/pull/277).
- Fix an e-matching performance regression introduced in [this commit](https://github.com/egraphs-good/egg/commit/ae8af8815231e4aba1b78962f8c07ce837ee1c0e#diff-1d06da761111802c793c6e5ca704bfa0d6336d0becf87fddff02d81548a838ab).
- Use `quanta` instead of `instant` crate to provide timing information. This can provide a huge speedup if you have lots of rules, since it avoids some syscalls.
Expand Down
89 changes: 10 additions & 79 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,43 +509,6 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}
}

/// When explanations are enabled, this function
/// produces an [`Explanation`] describing how the given expression came
/// to be in the egraph.
///
/// The [`Explanation`] begins with some expression that was added directly
/// into the egraph and ends with the given `expr`.
/// Note that this function can be called again to explain any intermediate terms
/// used in the output [`Explanation`].
pub fn explain_existance(&mut self, expr: &RecExpr<L>) -> Explanation<L> {
let id = self.add_expr_uncanonical(expr);
self.explain_existance_id(id)
}

/// Equivalent to calling [`explain_existance`](EGraph::explain_existance)`(`[`id_to_expr`](EGraph::id_to_expr)`(id))`
/// but more efficient
fn explain_existance_id(&mut self, id: Id) -> Explanation<L> {
if let Some(explain) = &mut self.explain {
explain.with_nodes(&self.nodes).explain_existance(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
}

/// Return an [`Explanation`] for why a pattern appears in the egraph.
pub fn explain_existance_pattern(
&mut self,
pattern: &PatternAst<L>,
subst: &Subst,
) -> Explanation<L> {
let id = self.add_instantiation_noncanonical(pattern, subst);
if let Some(explain) = &mut self.explain {
explain.with_nodes(&self.nodes).explain_existance(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
}

/// Get an explanation for why an expression matches a pattern.
pub fn explain_matches(
&mut self,
Expand Down Expand Up @@ -873,14 +836,6 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
} else {
new_node_q.push(false);
}
if let Some(explain) = &mut self.explain {
node.for_each(|child| {
// Set the existance reason for new nodes to their parent node.
if new_node_q[usize::from(child)] {
explain.set_existance_reason(new_ids[usize::from(child)], next_id);
}
});
}
new_ids.push(next_id);
}
*new_ids.last().unwrap()
Expand Down Expand Up @@ -919,13 +874,6 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
new_node_q.push(false);
}

if let Some(explain) = &mut self.explain {
node.for_each(|child| {
if new_node_q[usize::from(child)] {
explain.set_existance_reason(new_ids[usize::from(child)], next_id);
}
});
}
new_ids.push(next_id);
}
}
Expand Down Expand Up @@ -1059,11 +1007,11 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
*existing_explain
} else {
let new_id = self.unionfind.make_set();
explain.add(original.clone(), new_id, new_id);
explain.add(original.clone(), new_id);
debug_assert_eq!(Id::from(self.nodes.len()), new_id);
self.nodes.push(original);
self.unionfind.union(id, new_id);
explain.union(existing_id, new_id, Justification::Congruence, true);
explain.union(existing_id, new_id, Justification::Congruence);
new_id
}
} else {
Expand All @@ -1072,7 +1020,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
} else {
let id = self.make_new_eclass(enode, original.clone());
if let Some(explain) = self.explain.as_mut() {
explain.add(original, id, id);
explain.add(original, id);
}

// now that we updated explanations, run the analysis for the new eclass
Expand Down Expand Up @@ -1152,16 +1100,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
rule_name: impl Into<Symbol>,
) -> (Id, bool) {
let id1 = self.add_instantiation_noncanonical(from_pat, subst);
let size_before = self.unionfind.size();
let id2 = self.add_instantiation_noncanonical(to_pat, subst);
let rhs_new = self.unionfind.size() > size_before;

let did_union = self.perform_union(
id1,
id2,
Some(Justification::Rule(rule_name.into())),
rhs_new,
);
let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into())));
(self.find(id1), did_union)
}

Expand All @@ -1171,7 +1112,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important
/// to control explanations
pub fn union_trusted(&mut self, from: Id, to: Id, reason: impl Into<Symbol>) -> bool {
self.perform_union(from, to, Some(Justification::Rule(reason.into())), false)
self.perform_union(from, to, Some(Justification::Rule(reason.into())))
}

/// Unions two eclasses given their ids.
Expand All @@ -1194,17 +1135,11 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let caller = std::panic::Location::caller();
self.union_trusted(id1, id2, caller.to_string())
} else {
self.perform_union(id1, id2, None, false)
self.perform_union(id1, id2, None)
}
}

fn perform_union(
&mut self,
enode_id1: Id,
enode_id2: Id,
rule: Option<Justification>,
any_new_rhs: bool,
) -> bool {
fn perform_union(&mut self, enode_id1: Id, enode_id2: Id, rule: Option<Justification>) -> bool {
N::pre_union(self, enode_id1, enode_id2, &rule);

self.clean = false;
Expand All @@ -1226,7 +1161,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}

if let Some(explain) = &mut self.explain {
explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs);
explain.union(enode_id1, enode_id2, rule.unwrap());
}

// make id1 the new root
Expand Down Expand Up @@ -1401,12 +1336,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let mut node = self.nodes[usize::from(class)].clone();
node.update_children(|id| self.find_mut(id));
if let Some(memo_class) = self.memo.insert(node, class) {
let did_something = self.perform_union(
memo_class,
class,
Some(Justification::Congruence),
false,
);
let did_something =
self.perform_union(memo_class, class, Some(Justification::Congruence));
n_unions += did_something as usize;
}
}
Expand Down
135 changes: 2 additions & 133 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ struct ExplainNode {
// neighbors includes parent connections
neighbors: Vec<Connection>,
parent_connection: Connection,
// it was inserted because of:
// 1) it's parent is inserted (points to parent enode)
// 2) a rewrite instantiated it (points to adjacent enode)
// 3) it was inserted directly (points to itself)
// if 1 is true but it's also adjacent (2) then either works and it picks 2
existance_node: Id,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -914,11 +908,7 @@ impl<L: Language> Explain<L> {
}
}

pub(crate) fn set_existance_reason(&mut self, node: Id, existance_node: Id) {
self.explainfind[usize::from(node)].existance_node = existance_node;
}

pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id {
pub(crate) fn add(&mut self, node: L, set: Id) -> Id {
assert_eq!(self.explainfind.len(), usize::from(set));
self.uncanon_memo.insert(node, set);
self.explainfind.push(ExplainNode {
Expand All @@ -929,7 +919,6 @@ impl<L: Language> Explain<L> {
next: set,
current: set,
},
existance_node,
});
set
}
Expand Down Expand Up @@ -986,19 +975,10 @@ impl<L: Language> Explain<L> {
.insert((node2, node1), (BigUint::one(), node1));
}

pub(crate) fn union(
&mut self,
node1: Id,
node2: Id,
justification: Justification,
new_rhs: bool,
) {
pub(crate) fn union(&mut self, node1: Id, node2: Id, justification: Justification) {
if let Justification::Congruence = justification {
// assert!(self.node(node1).matches(self.node(node2)));
}
if new_rhs {
self.set_existance_reason(node2, node1)
}

self.make_leader(node1);
self.explainfind[usize::from(node1)].parent_connection.next = node2;
Expand Down Expand Up @@ -1103,21 +1083,6 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
for i in 0..self.explainfind.len() {
let explain_node = &self.explainfind[i];

// check that explanation reasons never form a cycle
let mut existance = i;
let mut seen_existance: HashSet<usize> = Default::default();
loop {
seen_existance.insert(existance);
let next = usize::from(self.explainfind[existance].existance_node);
if existance == next {
break;
}
existance = next;
if seen_existance.contains(&existance) {
panic!("Cycle in existance!");
}
}

if explain_node.parent_connection.next != Id::from(i) {
let mut current_explanation = self.node_to_flat_explanation(Id::from(i));
let mut next_explanation =
Expand Down Expand Up @@ -1159,17 +1124,6 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
Explanation::new(self.explain_enodes(left, right, &mut cache, &mut enode_cache, false))
}

pub(crate) fn explain_existance(&mut self, left: Id) -> Explanation<L> {
let mut cache = Default::default();
let mut enode_cache = Default::default();
Explanation::new(self.explain_enode_existance(
left,
self.node_to_explanation(left, &mut enode_cache),
&mut cache,
&mut enode_cache,
))
}

fn common_ancestor(&self, mut left: Id, mut right: Id) -> Id {
let mut seen_left: HashSet<Id> = Default::default();
let mut seen_right: HashSet<Id> = Default::default();
Expand Down Expand Up @@ -1255,62 +1209,6 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
(left_connections, right_connections)
}

fn explain_enode_existance(
&self,
node: Id,
rest_of_proof: Rc<TreeTerm<L>>,
cache: &mut ExplainCache<L>,
enode_cache: &mut NodeExplanationCache<L>,
) -> TreeExplanation<L> {
let graphnode = &self.explainfind[usize::from(node)];
let existance = graphnode.existance_node;
let existance_node = &self.explainfind[usize::from(existance)];
// case 1)
if existance == node {
return vec![self.node_to_explanation(node, enode_cache), rest_of_proof];
}

// case 2)
if graphnode.parent_connection.next == existance
|| existance_node.parent_connection.next == node
{
let mut connection = if graphnode.parent_connection.next == existance {
graphnode.parent_connection.clone()
} else {
existance_node.parent_connection.clone()
};

if graphnode.parent_connection.next == existance {
connection.is_rewrite_forward = !connection.is_rewrite_forward;
std::mem::swap(&mut connection.next, &mut connection.current);
}

let adj = self.explain_adjacent(connection, cache, enode_cache, false);
let mut exp = self.explain_enode_existance(existance, adj, cache, enode_cache);
exp.push(rest_of_proof);
return exp;
}

// case 3)
let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone();
let mut index_of_child = 0;
let mut found = false;
self.node(existance).for_each(|child| {
if found {
return;
}
if child == node {
found = true;
} else {
index_of_child += 1;
}
});
assert!(found);
new_rest_of_proof.child_proofs[index_of_child].push(rest_of_proof);

self.explain_enode_existance(existance, Rc::new(new_rest_of_proof), cache, enode_cache)
}

fn explain_enodes(
&self,
left: Id,
Expand Down Expand Up @@ -2048,35 +1946,6 @@ mod tests {

egraph.dot().to_dot("target/foo.dot").unwrap();
}

#[test]
fn simple_explain_exists() {
//! Same as previous test, but now I want to make a rewrite add some term and see it exists in
//! more then one step
use crate::SymbolLang;
init_logger();

let rws: Vec<Rewrite<SymbolLang, ()>> =
[rewrite!("makeb"; "a" => "b"), rewrite!("makec"; "b" => "c")].to_vec();
let mut egraph = Runner::default()
.with_explanations_enabled()
.without_explanation_length_optimization()
.with_expr(&"a".parse().unwrap())
.run(&rws)
.egraph;
egraph.rebuild();
let _a: Symbol = "a".parse().unwrap();
let _b: Symbol = "b".parse().unwrap();
let _c: Symbol = "c".parse().unwrap();
let mut exp = egraph.explain_existance(&"c".parse().unwrap());
println!("{:?}", exp.make_flat_explanation());
assert_eq!(
exp.make_flat_explanation().len(),
3,
"Expected 3 steps, got {:?}",
exp.make_flat_explanation()
);
}
}

#[test]
Expand Down
14 changes: 0 additions & 14 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,20 +489,6 @@ where
self.egraph.explain_equivalence(left, right)
}

/// Calls [`EGraph::explain_existance`](EGraph::explain_existance()).
pub fn explain_existance(&mut self, expr: &RecExpr<L>) -> Explanation<L> {
self.egraph.explain_existance(expr)
}

/// Calls [EGraph::explain_existance_pattern`](EGraph::explain_existance_pattern()).
pub fn explain_existance_pattern(
&mut self,
pattern: &PatternAst<L>,
subst: &Subst,
) -> Explanation<L> {
self.egraph.explain_existance_pattern(pattern, subst)
}

/// Get an explanation for why an expression matches a pattern.
pub fn explain_matches(
&mut self,
Expand Down

0 comments on commit 4747cfe

Please sign in to comment.