From 8bcbc34b7c71bc611e2636036ccf5655cc2f4261 Mon Sep 17 00:00:00 2001 From: Eytan Singher Date: Fri, 19 Apr 2024 14:52:36 +0300 Subject: [PATCH] Changed union to optionally accept a reason message --- src/egraph.rs | 21 ++++++++++++--------- src/language.rs | 2 +- src/multipattern.rs | 10 +++++----- src/pattern.rs | 4 ++-- src/rewrite.rs | 4 ++-- src/tutorials/_02_getting_started.rs | 2 +- tests/datalog.rs | 2 +- tests/lambda.rs | 2 +- tests/math.rs | 16 ++++++++-------- tests/prop.rs | 2 +- 10 files changed, 34 insertions(+), 31 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 472defe7..bc60ec35 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -157,7 +157,7 @@ impl> EGraph { /// let x = egraph.add(S::leaf("x")); /// let y = egraph.add(S::leaf("y")); /// // only one eclass - /// egraph.union(x, y); + /// egraph.union(x, y, None); /// egraph.rebuild(); /// /// assert_eq!(egraph.total_size(), 2); @@ -258,7 +258,7 @@ impl> EGraph { let added = egraph.add(enode); if let Some(existing) = ids.get(id) { - egraph.union(*existing, added); + egraph.union(*existing, added, Some("Egraph creation from enodes".into())); } else { ids.insert(*id, added); } @@ -518,7 +518,7 @@ impl> EGraph { /// let y = egraph.add(S::leaf("y")); /// assert_ne!(egraph.find(x), egraph.find(y)); /// - /// egraph.union(x, y); + /// egraph.union(x, y, None); /// egraph.rebuild(); /// assert_eq!(egraph.find(x), egraph.find(y)); /// ``` @@ -889,7 +889,7 @@ impl> EGraph { /// /// // if the query node isn't canonical, and its passed in by &mut instead of owned, /// // its children will be canonicalized - /// egraph.union(a, b); + /// egraph.union(a, b, None); /// egraph.rebuild(); /// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id)); /// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a])); @@ -958,7 +958,7 @@ impl> EGraph { /// let mut egraph: EGraph = EGraph::default().with_explanations_enabled(); /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); - /// egraph.union(a, b); + /// egraph.union(a, b, None); /// egraph.rebuild(); /// /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); @@ -1103,10 +1103,13 @@ impl> EGraph { /// See [`explain_equivalence`](Runner::explain_equivalence) for a more detailed /// explanation of the feature. #[track_caller] - pub fn union(&mut self, id1: Id, id2: Id) -> bool { + pub fn union(&mut self, id1: Id, id2: Id, reason: Option) -> bool { if self.explain.is_some() { - let caller = std::panic::Location::caller(); - self.union_trusted(id1, id2, caller.to_string()) + let reason = reason.map_or_else( + || std::panic::Location::caller().to_string(), + |x| x.to_string() + ); + self.union_trusted(id1, id2, reason) } else { self.perform_union(id1, id2, None, false) } @@ -1369,7 +1372,7 @@ impl> EGraph { /// let ay = egraph.add_expr(&"(+ a y)".parse().unwrap()); /// // Union x and y - /// egraph.union(x, y); + /// egraph.union(x, y, None); /// // Classes: [x y] [ax] [ay] [a] /// assert_eq!(egraph.find(x), egraph.find(y)); /// diff --git a/src/language.rs b/src/language.rs index 6414c63a..b10c1d97 100644 --- a/src/language.rs +++ b/src/language.rs @@ -673,7 +673,7 @@ impl Analysis for ConstantFolding { fn modify(egraph: &mut EGraph, id: Id) { if let Some(i) = egraph[id].data { let added = egraph.add(SimpleMath::Num(i)); - egraph.union(id, added); + egraph.union(id, added, None); } } } diff --git a/src/multipattern.rs b/src/multipattern.rs index 4fe61212..553bfff1 100644 --- a/src/multipattern.rs +++ b/src/multipattern.rs @@ -151,7 +151,7 @@ impl> Applier for MultiPattern { &self, egraph: &mut EGraph, matches: &[SearchMatches], - _rule_name: Symbol, + rule_name: Symbol, ) -> Vec { // TODO explanations? // the ids returned are kinda garbage @@ -164,7 +164,7 @@ impl> Applier for MultiPattern { id_buf.resize(p.as_ref().len(), 0.into()); let id1 = crate::pattern::apply_pat(&mut id_buf, p.as_ref(), egraph, &subst); if let Some(id2) = subst.insert(*v, id1) { - egraph.union(id1, id2); + egraph.union(id1, id2, Some(rule_name)); } if i == 0 { added.push(id1) @@ -225,7 +225,7 @@ mod tests { let _ = egraph.add_expr(&"(f a a)".parse().unwrap()); let ab = egraph.add_expr(&"(f a b)".parse().unwrap()); let ac = egraph.add_expr(&"(f a c)".parse().unwrap()); - egraph.union(ab, ac); + egraph.union(ab, ac, None); egraph.rebuild(); let n_matches = |multipattern: &str| -> usize { @@ -272,8 +272,8 @@ mod tests { let x1 = egraph.add_string("(tag x ctx1)"); let y1 = egraph.add_string("(tag y ctx1)"); let z1 = egraph.add_string("(tag z ctx2)"); - egraph.union(x1, y1); - egraph.union(y2, z2); + egraph.union(x1, y1, None); + egraph.union(y2, z2, None); let rules = vec![multi_rewrite!("context-transfer"; "?x = (tag ?a ?ctx1) = (tag ?b ?ctx1), ?t = (lte ?ctx1 ?ctx2), diff --git a/src/pattern.rs b/src/pattern.rs index 3143e82d..893167cf 100644 --- a/src/pattern.rs +++ b/src/pattern.rs @@ -362,7 +362,7 @@ where id = id_temp; } else { id = apply_pat(&mut id_buf, ast, egraph, subst); - did_something = egraph.union(id, mat.eclass); + did_something = egraph.union(id, mat.eclass, None); } if did_something { @@ -393,7 +393,7 @@ where } else { vec![] } - } else if egraph.union(eclass, id) { + } else if egraph.union(eclass, id, None) { vec![eclass] } else { vec![] diff --git a/src/rewrite.rs b/src/rewrite.rs index 6b34b48b..4b919569 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -309,7 +309,7 @@ where /// let b0c0 = egraph.add(Math::Mul([b0, c0])); /// let a0b0c0 = egraph.add(Math::Add([a0, b0c0])); /// // Don't forget to union the new node with the matched node! -/// if egraph.union(matched_id, a0b0c0) { +/// if egraph.union(matched_id, a0b0c0, None) { /// vec![a0b0c0] /// } else { /// vec![] @@ -636,7 +636,7 @@ mod tests { } } else { let added = egraph.add(S::leaf(&s)); - if egraph.union(added, eclass) { + if egraph.union(added, eclass, None) { vec![eclass] } else { vec![] diff --git a/src/tutorials/_02_getting_started.rs b/src/tutorials/_02_getting_started.rs index 895d63f5..59d8dc45 100644 --- a/src/tutorials/_02_getting_started.rs +++ b/src/tutorials/_02_getting_started.rs @@ -121,7 +121,7 @@ let pat: Pattern = "(foo ?x ?x)".parse().unwrap(); let matches = pat.search(&egraph); assert!(matches.is_empty()); -egraph.union(a, b); +egraph.union(a, b, None); // recall that rebuild must be called to "see" the effects of adds or unions egraph.rebuild(); diff --git a/tests/datalog.rs b/tests/datalog.rs index 2343d681..1671bded 100644 --- a/tests/datalog.rs +++ b/tests/datalog.rs @@ -20,7 +20,7 @@ impl DatalogExtTrait for EGraph { for e in s.split(',') { let exp = e.trim().parse().unwrap(); let id = self.add_expr(&exp); - self.union(true_id, id); + self.union(true_id, id, None); } } diff --git a/tests/lambda.rs b/tests/lambda.rs index 80ea4fbd..cbf39ab1 100644 --- a/tests/lambda.rs +++ b/tests/lambda.rs @@ -108,7 +108,7 @@ impl Analysis for LambdaAnalysis { ); } else { let const_id = egraph.add(c.0); - egraph.union(id, const_id); + egraph.union(id, const_id, None); } } } diff --git a/tests/math.rs b/tests/math.rs index a0d8c07a..12090244 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -93,7 +93,7 @@ impl Analysis for ConstantFold { ); } else { let added = egraph.add(Math::Constant(c)); - egraph.union(id, added); + egraph.union(id, added, None); } // to not prune, comment this out egraph[id].nodes.retain(|n| n.is_leaf()); @@ -517,13 +517,13 @@ fn test_medium_intersect() { let a = egraph1.add_expr(&"(sqrt (sin pi))".parse().unwrap()); let b = egraph1.add_expr(&"(* 1 pi)".parse().unwrap()); let pi = egraph1.add_expr(&"pi".parse().unwrap()); - egraph1.union(a, b); - egraph1.union(a, pi); + egraph1.union(a, b, None); + egraph1.union(a, pi, None); let c = egraph1.add_expr(&"(+ pi pi)".parse().unwrap()); - egraph1.union(ln, c); + egraph1.union(ln, c, None); let k = egraph1.add_expr(&"k".parse().unwrap()); let one = egraph1.add_expr(&"1".parse().unwrap()); - egraph1.union(k, one); + egraph1.union(k, one, None); egraph1.rebuild(); assert_eq!( @@ -535,11 +535,11 @@ fn test_medium_intersect() { let ln = egraph2.add_expr(&"(ln 2)".parse().unwrap()); let k = egraph2.add_expr(&"k".parse().unwrap()); let mk1 = egraph2.add_expr(&"(* k 1)".parse().unwrap()); - egraph2.union(mk1, k); + egraph2.union(mk1, k, None); let two = egraph2.add_expr(&"2".parse().unwrap()); - egraph2.union(mk1, two); + egraph2.union(mk1, two, None); let mul2pi = egraph2.add_expr(&"(+ (* 2 pi) (* 2 pi))".parse().unwrap()); - egraph2.union(ln, mul2pi); + egraph2.union(ln, mul2pi, None); egraph2.rebuild(); assert_eq!( diff --git a/tests/prop.rs b/tests/prop.rs index ed1c7469..5c97899b 100644 --- a/tests/prop.rs +++ b/tests/prop.rs @@ -112,7 +112,7 @@ fn prove_something(name: &str, start: &str, rewrites: &[Rewrite], goals: &[&str] // this is needed for the soundness of lem_imply let true_id = runner.egraph.add(Prop::Bool(true)); let root = runner.roots[0]; - runner.egraph.union(root, true_id); + runner.egraph.union(root, true_id, None); runner.egraph.rebuild(); let egraph = runner.run(rewrites).egraph;