Skip to content

Commit

Permalink
Add Existance Explanations (egraphs-good#119)
Browse files Browse the repository at this point in the history
* implement existance explanations

* update docs
  • Loading branch information
oflatt authored Nov 16, 2021
1 parent c637bbd commit 1d4687a
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 20 deletions.
100 changes: 86 additions & 14 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub struct EGraph<L: Language, N: Analysis<L>> {
unionfind: UnionFind,
#[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
memo: HashMap<L, Id>,
to_union: Vec<(Id, Id, Option<Symbol>)>,
to_union: Vec<(Id, Id, Option<Symbol>, bool)>,
pending: Vec<(L, Id)>,
analysis_pending: IndexSet<(L, Id)>,
#[cfg_attr(
Expand Down Expand Up @@ -212,6 +212,35 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}
}

/// When explanations are enabled, this function
/// produces an [`Explanation`] describing how the given expression came
/// to be in the egraph.
///
/// The [`Explanation`] begins with some expression that was added directly
/// into the egraph and ends with the given `expr`.
/// Note that this function can be called again to explain any intermediate terms
/// used in the output [`Explanation`].
pub fn explain_existance(&mut self, expr: &RecExpr<L>) -> Explanation<L> {
if let Some(explain) = &mut self.explain {
explain.explain_existance(expr, &self.memo, &mut self.unionfind)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
}

/// Return an [`Explanation`] for why a pattern appears in the egraph.
pub fn explain_existance_pattern(
&mut self,
pattern: &PatternAst<L>,
subst: &Subst,
) -> Explanation<L> {
if let Some(explain) = &mut self.explain {
explain.explain_existance_pattern(pattern, subst, &self.memo, &mut self.unionfind)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
}

/// Get an explanation for why an expression matches a pattern.
pub fn explain_matches(
&mut self,
Expand Down Expand Up @@ -304,9 +333,25 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
pub fn add_expr(&mut self, expr: &RecExpr<L>) -> Id {
let nodes = expr.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
let mut new_node_q = Vec::with_capacity(nodes.len());
for node in nodes {
let node = node.clone().map_children(|i| new_ids[usize::from(i)]);
new_ids.push(self.add(node))
let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
let size_before = self.unionfind.size();
let next_id = self.add(new_node);
if self.unionfind.size() > size_before {
new_node_q.push(true);
} else {
new_node_q.push(false);
}
if let Some(explain) = &mut self.explain {
node.for_each(|child| {
// Set the existance reason for new nodes to their parent node.
if new_node_q[usize::from(child)] {
explain.set_existance_reason(new_ids[usize::from(child)], next_id);
}
});
}
new_ids.push(next_id);
}
*new_ids.last().unwrap()
}
Expand All @@ -315,15 +360,31 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
pub fn add_instantiation(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
let nodes = pat.as_ref().as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
let mut new_node_q = Vec::with_capacity(nodes.len());
for node in nodes {
match node {
ENodeOrVar::Var(var) => {
let id = subst[*var];
new_ids.push(id);
new_node_q.push(false);
}
ENodeOrVar::ENode(node) => {
let node = node.clone().map_children(|i| new_ids[usize::from(i)]);
new_ids.push(self.add(node))
let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
let size_before = self.unionfind.size();
let next_id = self.add(new_node);
if self.unionfind.size() > size_before {
new_node_q.push(true);
} else {
new_node_q.push(false);
}
if let Some(explain) = &mut self.explain {
node.for_each(|child| {
if new_node_q[usize::from(child)] {
explain.set_existance_reason(new_ids[usize::from(child)], next_id);
}
});
}
new_ids.push(next_id);
}
}
}
Expand Down Expand Up @@ -394,7 +455,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
self.lookup(&mut enode).unwrap_or_else(|| {
let id = self.unionfind.make_set();
if let Some(explain) = &mut self.explain {
explain.add(enode.clone(), id);
explain.add(enode.clone(), id, id);
}
log::trace!(" ...adding to {}", id);
let class = EClass {
Expand Down Expand Up @@ -465,10 +526,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
rule_name: impl Into<Symbol>,
) -> (Id, bool) {
let id1 = self.add_instantiation(from_pat, subst);
let size_before = self.unionfind.size();
let id2 = self.add_instantiation(to_pat, subst);
let rhs_new = self.unionfind.size() > size_before;
(
id1,
self.union_with_justification(id1, id2, from_pat, to_pat, subst, rule_name),
self.union_with_justification(id1, id2, from_pat, to_pat, subst, rule_name, rhs_new),
)
}

Expand All @@ -480,6 +543,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
to_pat: &PatternAst<L>,
subst: &Subst,
rule_name: impl Into<Symbol>,
rhs_new: bool,
) -> bool {
self.clean = false;
if let Some(explain) = &mut self.explain {
Expand All @@ -488,9 +552,11 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
} else {
let left_added =
explain.add_match(from_pat, subst, &self.memo, &mut self.unionfind);
let size_before_right = self.unionfind.size();
let right_added = explain.add_match(to_pat, subst, &self.memo, &mut self.unionfind);
let any_new_rhs = rhs_new || self.unionfind.size() > size_before_right;
self.to_union
.push((left_added, right_added, Some(rule_name.into())));
.push((left_added, right_added, Some(rule_name.into()), any_new_rhs));
true
}
} else {
Expand Down Expand Up @@ -523,12 +589,18 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
if self.find_mut(id1) == self.find_mut(id2) {
false
} else {
self.to_union.push((id1, id2, None));
self.to_union.push((id1, id2, None, false));
true
}
}

fn perform_union(&mut self, enode_id1: Id, enode_id2: Id, rule: Option<Justification>) -> bool {
fn perform_union(
&mut self,
enode_id1: Id,
enode_id2: Id,
rule: Option<Justification>,
any_new_rhs: bool,
) -> bool {
let mut id1 = self.find_mut(enode_id1);
let mut id2 = self.find_mut(enode_id2);
if id1 == id2 {
Expand All @@ -544,7 +616,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
N::pre_union(self, id1, id2);

if let Some(explain) = &mut self.explain {
explain.union(enode_id1, enode_id2, rule.unwrap());
explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs);
} else {
assert!(rule.is_none());
}
Expand Down Expand Up @@ -703,8 +775,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
while !self.to_union.is_empty() {
let mut current = vec![];
std::mem::swap(&mut self.to_union, &mut current);
for (id1, id2, rule) in current.into_iter() {
self.perform_union(id1, id2, rule.map(Justification::Rule));
for (id1, id2, rule, any_new_rhs) in current.into_iter() {
self.perform_union(id1, id2, rule.map(Justification::Rule), any_new_rhs);
}
}
}
Expand All @@ -728,7 +800,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
if self.explain.is_some() {
reason = Some(Justification::Congruence);
}
let did_something = self.perform_union(memo_class, class, reason);
let did_something = self.perform_union(memo_class, class, reason, false);
n_unions += did_something as usize;
}
}
Expand Down
135 changes: 131 additions & 4 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ struct ExplainNode<L: Language> {
next: Id,
current: Id,
justification: Justification,
// it was inserted because of:
// 1) it's parent is inserted (points to parent enode)
// 2) a rewrite instantiated it (points to adjacent enode)
// 3) it was inserted directly (points to itself)
// if 1 is true but it's also adjacent (2) then either works and it picks 2
existance_node: Id,
is_rewrite_forward: bool,
}

Expand Down Expand Up @@ -752,6 +758,22 @@ impl<L: Language> Explain<L> {
let rule_table = Explain::make_rule_table(rules);
for i in 0..self.explainfind.len() {
let explain_node = &self.explainfind[i];

// check that explanation reasons never form a cycle
let mut existance = i;
let mut seen_existance: HashSet<usize> = Default::default();
loop {
seen_existance.insert(existance);
let next = usize::from(self.explainfind[existance].existance_node);
if existance == next {
break;
}
existance = next;
if seen_existance.contains(&existance) {
assert!(false, "Cycle in existance!");
}
}

if explain_node.next != Id::from(i) {
let mut current_explanation = self.node_to_flat_explanation(Id::from(i));
let mut next_explanation = self.node_to_flat_explanation(explain_node.next);
Expand Down Expand Up @@ -781,13 +803,18 @@ impl<L: Language> Explain<L> {
}
}

pub fn add(&mut self, node: L, set: Id) -> Id {
pub(crate) fn set_existance_reason(&mut self, node: Id, existance_node: Id) {
self.explainfind[usize::from(node)].existance_node = existance_node;
}

pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id {
self.uncanon_memo.insert(node.clone(), set);
self.explainfind.push(ExplainNode {
node,
justification: Justification::Congruence,
next: set,
current: set,
existance_node,
is_rewrite_forward: false,
});
set
Expand Down Expand Up @@ -853,11 +880,18 @@ impl<L: Language> Explain<L> {
.map_children(|id| unionfind.find(id))
);

let new_congruent_id = self.add(new_congruent_node, unionfind.make_set());
let new_congruent_id =
self.add(new_congruent_node, unionfind.make_set(), congruent_id);

match_ids.push(new_congruent_id);
// make the congruent_id we found the leader
unionfind.union(congruent_class, new_congruent_id);
self.union(new_congruent_id, congruent_id, Justification::Congruence);
self.union(
new_congruent_id,
congruent_id,
Justification::Congruence,
false,
);
}
}
}
Expand All @@ -880,7 +914,17 @@ impl<L: Language> Explain<L> {
}
}

pub(crate) fn union(&mut self, node1: Id, node2: Id, justification: Justification) {
pub(crate) fn union(
&mut self,
node1: Id,
node2: Id,
justification: Justification,
new_rhs: bool,
) {
if new_rhs {
self.set_existance_reason(node2, node1)
}

self.make_leader(node1);
self.explainfind[usize::from(node1)].next = node2;
self.explainfind[usize::from(node1)].justification = justification;
Expand Down Expand Up @@ -914,6 +958,37 @@ impl<L: Language> Explain<L> {
Explanation::new(self.explain_enodes(left_added, right_added, &mut cache))
}

pub(crate) fn explain_existance(
&mut self,
left: &RecExpr<L>,
memo: &HashMap<L, Id>,
unionfind: &mut UnionFind,
) -> Explanation<L> {
let left_added = self.add_expr(left, memo, unionfind);
let mut cache = Default::default();
Explanation::new(self.explain_enode_existance(
left_added,
Rc::new(self.node_to_explanation(left_added)),
&mut cache,
))
}

pub(crate) fn explain_existance_pattern(
&mut self,
left: &PatternAst<L>,
subst: &Subst,
memo: &HashMap<L, Id>,
unionfind: &mut UnionFind,
) -> Explanation<L> {
let left_added = self.add_match(left, &subst, memo, unionfind);
let mut cache = Default::default();
Explanation::new(self.explain_enode_existance(
left_added,
Rc::new(self.node_to_explanation(left_added)),
&mut cache,
))
}

fn common_ancestor(&self, mut left: Id, mut right: Id) -> Id {
let mut seen_left: HashSet<Id> = Default::default();
let mut seen_right: HashSet<Id> = Default::default();
Expand Down Expand Up @@ -953,6 +1028,58 @@ impl<L: Language> Explain<L> {
}
}

fn explain_enode_existance(
&self,
node: Id,
rest_of_proof: Rc<TreeTerm<L>>,
cache: &mut ExplainCache<L>,
) -> TreeExplanation<L> {
let graphnode = &self.explainfind[usize::from(node)];
let existance = graphnode.existance_node;
let existance_node = &self.explainfind[usize::from(existance)];
// case 1)
if existance == node {
return vec![Rc::new(self.node_to_explanation(node)), rest_of_proof];
}

// case 2)
if graphnode.next == existance || existance_node.next == node {
let direction;
let justification;
if graphnode.next == existance {
direction = !graphnode.is_rewrite_forward;
justification = &graphnode.justification;
} else {
direction = existance_node.is_rewrite_forward;
justification = &existance_node.justification;
}
return self.explain_enode_existance(
existance,
self.explain_adjacent(existance, node, direction, justification, cache),
cache,
);
}

// case 3)
let mut new_rest_of_proof = self.node_to_explanation(existance);
let mut index_of_child = 0;
let mut found = false;
existance_node.node.for_each(|child| {
if found {
return;
}
if child == node {
found = true;
} else {
index_of_child += 1;
}
});
assert!(found);
new_rest_of_proof.child_proofs[index_of_child].push(rest_of_proof);

self.explain_enode_existance(existance, Rc::new(new_rest_of_proof), cache)
}

fn explain_enodes(
&self,
left: Id,
Expand Down
Loading

0 comments on commit 1d4687a

Please sign in to comment.