Skip to content

Commit

Permalink
Changed union to optionally accept a reason message
Browse files Browse the repository at this point in the history
  • Loading branch information
eytans committed Apr 19, 2024
1 parent c7e9284 commit 8bcbc34
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 31 deletions.
21 changes: 12 additions & 9 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// let x = egraph.add(S::leaf("x"));
/// let y = egraph.add(S::leaf("y"));
/// // only one eclass
/// egraph.union(x, y);
/// egraph.union(x, y, None);
/// egraph.rebuild();
///
/// assert_eq!(egraph.total_size(), 2);
Expand Down Expand Up @@ -258,7 +258,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

let added = egraph.add(enode);
if let Some(existing) = ids.get(id) {
egraph.union(*existing, added);
egraph.union(*existing, added, Some("Egraph creation from enodes".into()));
} else {
ids.insert(*id, added);
}
Expand Down Expand Up @@ -518,7 +518,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// let y = egraph.add(S::leaf("y"));
/// assert_ne!(egraph.find(x), egraph.find(y));
///
/// egraph.union(x, y);
/// egraph.union(x, y, None);
/// egraph.rebuild();
/// assert_eq!(egraph.find(x), egraph.find(y));
/// ```
Expand Down Expand Up @@ -889,7 +889,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
///
/// // if the query node isn't canonical, and its passed in by &mut instead of owned,
/// // its children will be canonicalized
/// egraph.union(a, b);
/// egraph.union(a, b, None);
/// egraph.rebuild();
/// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id));
/// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a]));
Expand Down Expand Up @@ -958,7 +958,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_enabled();
/// let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
/// let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
/// egraph.union(a, b);
/// egraph.union(a, b, None);
/// egraph.rebuild();
///
/// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
Expand Down Expand Up @@ -1103,10 +1103,13 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// See [`explain_equivalence`](Runner::explain_equivalence) for a more detailed
/// explanation of the feature.
#[track_caller]
pub fn union(&mut self, id1: Id, id2: Id) -> bool {
pub fn union(&mut self, id1: Id, id2: Id, reason: Option<Symbol>) -> bool {
if self.explain.is_some() {
let caller = std::panic::Location::caller();
self.union_trusted(id1, id2, caller.to_string())
let reason = reason.map_or_else(
|| std::panic::Location::caller().to_string(),
|x| x.to_string()
);
self.union_trusted(id1, id2, reason)
} else {
self.perform_union(id1, id2, None, false)
}
Expand Down Expand Up @@ -1369,7 +1372,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// let ay = egraph.add_expr(&"(+ a y)".parse().unwrap());
/// // Union x and y
/// egraph.union(x, y);
/// egraph.union(x, y, None);
/// // Classes: [x y] [ax] [ay] [a]
/// assert_eq!(egraph.find(x), egraph.find(y));
///
Expand Down
2 changes: 1 addition & 1 deletion src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ impl Analysis<SimpleMath> for ConstantFolding {
fn modify(egraph: &mut EGraph<SimpleMath, Self>, id: Id) {
if let Some(i) = egraph[id].data {
let added = egraph.add(SimpleMath::Num(i));
egraph.union(id, added);
egraph.union(id, added, None);
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/multipattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl<L: Language, A: Analysis<L>> Applier<L, A> for MultiPattern<L> {
&self,
egraph: &mut EGraph<L, A>,
matches: &[SearchMatches<L>],
_rule_name: Symbol,
rule_name: Symbol,
) -> Vec<Id> {
// TODO explanations?
// the ids returned are kinda garbage
Expand All @@ -164,7 +164,7 @@ impl<L: Language, A: Analysis<L>> Applier<L, A> for MultiPattern<L> {
id_buf.resize(p.as_ref().len(), 0.into());
let id1 = crate::pattern::apply_pat(&mut id_buf, p.as_ref(), egraph, &subst);
if let Some(id2) = subst.insert(*v, id1) {
egraph.union(id1, id2);
egraph.union(id1, id2, Some(rule_name));
}
if i == 0 {
added.push(id1)
Expand Down Expand Up @@ -225,7 +225,7 @@ mod tests {
let _ = egraph.add_expr(&"(f a a)".parse().unwrap());
let ab = egraph.add_expr(&"(f a b)".parse().unwrap());
let ac = egraph.add_expr(&"(f a c)".parse().unwrap());
egraph.union(ab, ac);
egraph.union(ab, ac, None);
egraph.rebuild();

let n_matches = |multipattern: &str| -> usize {
Expand Down Expand Up @@ -272,8 +272,8 @@ mod tests {
let x1 = egraph.add_string("(tag x ctx1)");
let y1 = egraph.add_string("(tag y ctx1)");
let z1 = egraph.add_string("(tag z ctx2)");
egraph.union(x1, y1);
egraph.union(y2, z2);
egraph.union(x1, y1, None);
egraph.union(y2, z2, None);
let rules = vec![multi_rewrite!("context-transfer";
"?x = (tag ?a ?ctx1) = (tag ?b ?ctx1),
?t = (lte ?ctx1 ?ctx2),
Expand Down
4 changes: 2 additions & 2 deletions src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ where
id = id_temp;
} else {
id = apply_pat(&mut id_buf, ast, egraph, subst);
did_something = egraph.union(id, mat.eclass);
did_something = egraph.union(id, mat.eclass, None);
}

if did_something {
Expand Down Expand Up @@ -393,7 +393,7 @@ where
} else {
vec![]
}
} else if egraph.union(eclass, id) {
} else if egraph.union(eclass, id, None) {
vec![eclass]
} else {
vec![]
Expand Down
4 changes: 2 additions & 2 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ where
/// let b0c0 = egraph.add(Math::Mul([b0, c0]));
/// let a0b0c0 = egraph.add(Math::Add([a0, b0c0]));
/// // Don't forget to union the new node with the matched node!
/// if egraph.union(matched_id, a0b0c0) {
/// if egraph.union(matched_id, a0b0c0, None) {
/// vec![a0b0c0]
/// } else {
/// vec![]
Expand Down Expand Up @@ -636,7 +636,7 @@ mod tests {
}
} else {
let added = egraph.add(S::leaf(&s));
if egraph.union(added, eclass) {
if egraph.union(added, eclass, None) {
vec![eclass]
} else {
vec![]
Expand Down
2 changes: 1 addition & 1 deletion src/tutorials/_02_getting_started.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ let pat: Pattern<SymbolLang> = "(foo ?x ?x)".parse().unwrap();
let matches = pat.search(&egraph);
assert!(matches.is_empty());
egraph.union(a, b);
egraph.union(a, b, None);
// recall that rebuild must be called to "see" the effects of adds or unions
egraph.rebuild();
Expand Down
2 changes: 1 addition & 1 deletion tests/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl DatalogExtTrait for EGraph<Lang, ()> {
for e in s.split(',') {
let exp = e.trim().parse().unwrap();
let id = self.add_expr(&exp);
self.union(true_id, id);
self.union(true_id, id, None);
}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/lambda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl Analysis<Lambda> for LambdaAnalysis {
);
} else {
let const_id = egraph.add(c.0);
egraph.union(id, const_id);
egraph.union(id, const_id, None);
}
}
}
Expand Down
16 changes: 8 additions & 8 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl Analysis<Math> for ConstantFold {
);
} else {
let added = egraph.add(Math::Constant(c));
egraph.union(id, added);
egraph.union(id, added, None);
}
// to not prune, comment this out
egraph[id].nodes.retain(|n| n.is_leaf());
Expand Down Expand Up @@ -517,13 +517,13 @@ fn test_medium_intersect() {
let a = egraph1.add_expr(&"(sqrt (sin pi))".parse().unwrap());
let b = egraph1.add_expr(&"(* 1 pi)".parse().unwrap());
let pi = egraph1.add_expr(&"pi".parse().unwrap());
egraph1.union(a, b);
egraph1.union(a, pi);
egraph1.union(a, b, None);
egraph1.union(a, pi, None);
let c = egraph1.add_expr(&"(+ pi pi)".parse().unwrap());
egraph1.union(ln, c);
egraph1.union(ln, c, None);
let k = egraph1.add_expr(&"k".parse().unwrap());
let one = egraph1.add_expr(&"1".parse().unwrap());
egraph1.union(k, one);
egraph1.union(k, one, None);
egraph1.rebuild();

assert_eq!(
Expand All @@ -535,11 +535,11 @@ fn test_medium_intersect() {
let ln = egraph2.add_expr(&"(ln 2)".parse().unwrap());
let k = egraph2.add_expr(&"k".parse().unwrap());
let mk1 = egraph2.add_expr(&"(* k 1)".parse().unwrap());
egraph2.union(mk1, k);
egraph2.union(mk1, k, None);
let two = egraph2.add_expr(&"2".parse().unwrap());
egraph2.union(mk1, two);
egraph2.union(mk1, two, None);
let mul2pi = egraph2.add_expr(&"(+ (* 2 pi) (* 2 pi))".parse().unwrap());
egraph2.union(ln, mul2pi);
egraph2.union(ln, mul2pi, None);
egraph2.rebuild();

assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion tests/prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ fn prove_something(name: &str, start: &str, rewrites: &[Rewrite], goals: &[&str]
// this is needed for the soundness of lem_imply
let true_id = runner.egraph.add(Prop::Bool(true));
let root = runner.roots[0];
runner.egraph.union(root, true_id);
runner.egraph.union(root, true_id, None);
runner.egraph.rebuild();

let egraph = runner.run(rewrites).egraph;
Expand Down

0 comments on commit 8bcbc34

Please sign in to comment.