From 241df810726e59aca9625ad16d5067349ab54a6a Mon Sep 17 00:00:00 2001 From: Eytan Singher Date: Mon, 1 Jul 2024 15:51:26 +0300 Subject: [PATCH] Starting to add support for multipattern. Need to change existance\union to multiple reasons and update apply_matches --- Cargo.toml | 2 + src/egraph.rs | 5 +- src/explain.rs | 283 +++++++++++++++++++++++++++++++++++++++----- src/language.rs | 2 +- src/multipattern.rs | 24 +++- src/run.rs | 1 + src/test.rs | 2 + 7 files changed, 280 insertions(+), 39 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0e0f73e3..31c6149e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ serde_json = { version = "1.0.81", optional = true } saturating = "0.1.0" rayon = { version = "1.10.0", optional = true } crossbeam = { version = "0.8.4", optional = true, features = ["crossbeam-channel"] } +itertools = "0.13.0" [dev-dependencies] ordered-float = "3.0.0" @@ -55,6 +56,7 @@ serde-1 = [ ] wasm-bindgen = [] parallel = ["rayon", "crossbeam"] +check_proof = [] # private features for testing test-explanations = [] diff --git a/src/egraph.rs b/src/egraph.rs index f522505a..ee23493a 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -854,9 +854,10 @@ impl> EGraph { for node in nodes { let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]); let size_before = self.unionfind.size(); - let next_id = self.add_uncanonical(new_node); + let next_id = self.add_uncanonical(new_node.clone()); if self.unionfind.size() > size_before { new_node_q.push(true); + println!("Pushed new node {new_node} to egraph"); } else { new_node_q.push(false); } @@ -1144,7 +1145,6 @@ impl> EGraph { let size_before = self.unionfind.size(); let id2 = self.add_instantiation_noncanonical(to_pat, subst); let rhs_new = self.unionfind.size() > size_before; - let did_union = self.perform_union( id1, id2, @@ -1492,6 +1492,7 @@ impl> EGraph { n_unions } + #[cfg(feature = "check_proof")] pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite]) -> bool { if let Some(explain) = &mut self.explain { explain.with_nodes(&self.nodes).check_each_explain(rules) diff --git a/src/explain.rs b/src/explain.rs index 08a3ae3a..f1425592 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,8 +1,11 @@ use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; +use std::iter::FromIterator; use std::ops::{Deref, DerefMut}; use std::rc::Rc; +use itertools::Itertools; +use log::{debug, info, trace, warn}; use num_bigint::BigUint; use num_traits::identities::{One, Zero}; @@ -108,7 +111,109 @@ pub type TreeExplanation = Vec>>; /// is connected to the previous by exactly one rewrite. /// /// See [`FlatTerm`] for more details on how to find this rewrite. -pub type FlatExplanation = Vec>; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlatExplanation(Vec>); + +// implement iterator and index for FlatExplanation +impl Deref for FlatExplanation { + type Target = Vec>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for FlatExplanation { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl FlatExplanation { + /// Check if the explanation is empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Get the length of the explanation. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Get the last element of the explanation. + pub fn last(&self) -> Option<&FlatTerm> { + self.0.last() + } + + /// Get the last element of the explanation. 
+ pub fn last_mut(&mut self) -> Option<&mut FlatTerm> { + self.0.last_mut() + } + + /// Get the first element of the explanation. + pub fn first(&self) -> Option<&FlatTerm> { + self.0.first() + } + + /// Get the first element of the explanation. + pub fn first_mut(&mut self) -> Option<&mut FlatTerm> { + self.0.first_mut() + } + + /// Get the last element of the explanation and remove it. + pub fn pop(&mut self) -> Option> { + self.0.pop() + } + + /// Extend the explanation with another. + pub fn extend(&mut self, other: FlatExplanation) { + self.0.extend(other.0); + } +} + +impl IntoIterator for FlatExplanation { + type Item = FlatTerm; + type IntoIter = std::vec::IntoIter>; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a, L: Language> IntoIterator for &'a FlatExplanation { + type Item = &'a FlatTerm; + type IntoIter = std::slice::Iter<'a, FlatTerm>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +// For &mut A +impl<'a, L: Language> IntoIterator for &'a mut FlatExplanation { + type Item = &'a mut FlatTerm; + type IntoIter = std::slice::IterMut<'a, FlatTerm>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter_mut() + } +} + +impl FromIterator> for FlatExplanation { + fn from_iter>>(iter: T) -> Self { + FlatExplanation(iter.into_iter().collect()) + } +} + +impl Display for FlatExplanation { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // print each term in the explanation + // Term with subterms should have them indented + Ok(for term in self.iter() { + write!(f, "{:?}\n", term)?; + }) + } +} /// A vector of equalities based on enode ids. Each entry represents /// two enode ids that are equal and why. @@ -371,7 +476,10 @@ impl Explanation { self.flat_explanation.as_ref().unwrap() } } +} +#[cfg(feature = "check_proof")] +impl Explanation { /// Check the validity of the explanation with respect to the given rules. /// This only is able to check rule applications when the rules are implement `get_pattern_ast`. 
pub fn check_proof<'a, R, N: Analysis>(&mut self, rules: R) @@ -495,8 +603,13 @@ impl TreeTerm { } fn flatten_proof(proof: &[Rc>]) -> FlatExplanation { - let mut flat_proof: FlatExplanation = vec![]; + println!("Starting to flatten prrof (len is {})", proof.len()); + let proof_parts = proof.iter().join(", "); + println!("Proof parts: {proof_parts}"); + let mut flat_proof: FlatExplanation = FlatExplanation(vec![]); for tree in proof { + let asstr = format!("{}", tree); + println!("Flattening tree {asstr}"); let mut explanation = tree.flatten_explanation(); if !flat_proof.is_empty() @@ -509,6 +622,7 @@ impl TreeTerm { flat_proof.extend(explanation); } + println!("Done flattening"); flat_proof } @@ -546,15 +660,17 @@ impl TreeTerm { let mut proof = vec![]; let mut child_proofs = vec![]; let mut representative_terms = vec![]; + println!("Flattening self: {}", self); for child_explanation in &self.child_proofs { let flat_proof = TreeTerm::flatten_proof(child_explanation); + println!("Flattened child proof: {flat_proof}"); representative_terms.push(flat_proof[0].remove_rewrites()); child_proofs.push(flat_proof); } proof.push(FlatTerm::new( self.node.clone(), - representative_terms.clone(), + FlatExplanation(representative_terms.clone()), )); for (i, child_proof) in child_proofs.iter().enumerate() { @@ -571,7 +687,7 @@ impl TreeTerm { } } - proof.push(FlatTerm::new(self.node.clone(), children)); + proof.push(FlatTerm::new(self.node.clone(), FlatExplanation(children))); } representative_terms[i] = child_proof.last().unwrap().remove_rewrites(); } @@ -579,7 +695,7 @@ impl TreeTerm { proof[0].backward_rule = self.backward_rule; proof[0].forward_rule = self.forward_rule; - proof + FlatExplanation(proof) } } @@ -710,7 +826,7 @@ impl FlatTerm { } } -impl Display for TreeTerm { +impl Display for TreeTerm { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let mut buf = String::new(); let width = 80; @@ -719,7 +835,7 @@ impl Display for TreeTerm { } } -impl TreeTerm { +impl TreeTerm { /// Convert this TreeTerm to an S-expression. fn get_sexp(&self) -> Sexp { self.get_sexp_with_bindings(&Default::default()) @@ -777,17 +893,8 @@ impl TreeTerm { } } +#[cfg(proof_check)] impl FlatTerm { - /// Construct a new FlatTerm given a node and its children. - pub fn new(node: L, children: FlatExplanation) -> FlatTerm { - FlatTerm { - node, - backward_rule: None, - forward_rule: None, - children, - } - } - /// Rewrite the FlatTerm by matching the lhs and substituting the rhs. /// The lhs must be guaranteed to match. pub fn rewrite(&self, lhs: &PatternAst, rhs: &PatternAst) -> FlatTerm { @@ -798,6 +905,19 @@ impl FlatTerm { FlatTerm::from_pattern(rhs_nodes, rhs_nodes.len() - 1, &bindings) } +} + +impl FlatTerm { + /// Construct a new FlatTerm given a node and its children. + pub fn new(node: L, children: FlatExplanation) -> FlatTerm { + FlatTerm { + node, + backward_rule: None, + forward_rule: None, + children, + } + } + /// Checks if this term or any child has a [`forward_rule`](FlatTerm::forward_rule). 
pub fn has_rewrite_forward(&self) -> bool { self.forward_rule.is_some() @@ -832,7 +952,7 @@ impl FlatTerm { )); acc }); - FlatTerm::new(node.clone(), children) + FlatTerm::new(node.clone(), FlatExplanation(children)) } } } @@ -1071,7 +1191,10 @@ impl<'x, L: Language> ExplainNodes<'x, L> { pub(crate) fn node(&self, node_id: Id) -> &L { &self.nodes[usize::from(node_id)] } - fn node_to_explanation( + + /// Reconstructs the expression represented by node_id into a tree term by using self.node + /// to find the appropriate e-node. + fn expr_reconstruction( &self, node_id: Id, cache: &mut NodeExplanationCache, @@ -1080,8 +1203,11 @@ impl<'x, L: Language> ExplainNodes<'x, L> { existing.clone() } else { let node = self.node(node_id).clone(); + let node_text = format!("{}", node); + let direct_children_text = node.children().iter().map(|c| format!("{}",self.node(*c))).join(", "); + trace!("Node text: {node_text} - {direct_children_text}"); let children = node.fold(vec![], |mut sofar, child| { - sofar.push(vec![self.node_to_explanation(child, cache)]); + sofar.push(vec![self.expr_reconstruction(child, cache)]); sofar }); let res = Rc::new(TreeTerm::new(node, children)); @@ -1096,9 +1222,10 @@ impl<'x, L: Language> ExplainNodes<'x, L> { sofar.push(self.node_to_flat_explanation(child)); sofar }); - FlatTerm::new(node, children) + FlatTerm::new(node, FlatExplanation(children)) } + #[cfg(feature = "check_proof")] pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { let rule_table = Explain::make_rule_table(rules); for i in 0..self.explainfind.len() { @@ -1165,7 +1292,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { let mut enode_cache = Default::default(); Explanation::new(self.explain_enode_existance( left, - self.node_to_explanation(left, &mut enode_cache), + self.expr_reconstruction(left, &mut enode_cache), &mut cache, &mut enode_cache, )) @@ -1259,7 +1386,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { fn explain_enode_existance( &self, node: Id, - rest_of_proof: Rc>, + left_to_prove: Rc>, cache: &mut ExplainCache, enode_cache: &mut NodeExplanationCache, ) -> TreeExplanation { @@ -1268,7 +1395,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { let existance_node = &self.explainfind[usize::from(existance)]; // case 1) if existance == node { - return vec![self.node_to_explanation(node, enode_cache), rest_of_proof]; + return vec![self.expr_reconstruction(node, enode_cache), left_to_prove]; } // case 2) @@ -1287,13 +1414,15 @@ impl<'x, L: Language> ExplainNodes<'x, L> { } let adj = self.explain_adjacent(connection, cache, enode_cache, false); + let adj_exp_text = adj.to_string(); + trace!("Adjacent explanation: {adj_exp_text}"); let mut exp = self.explain_enode_existance(existance, adj, cache, enode_cache); - exp.push(rest_of_proof); + exp.push(left_to_prove); return exp; } // case 3) - let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone(); + let mut new_rest_of_proof = (*self.expr_reconstruction(existance, enode_cache)).clone(); let mut index_of_child = 0; let mut found = false; self.node(existance).for_each(|child| { @@ -1307,7 +1436,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { } }); assert!(found); - new_rest_of_proof.child_proofs[index_of_child].push(rest_of_proof); + new_rest_of_proof.child_proofs[index_of_child].push(left_to_prove); self.explain_enode_existance(existance, Rc::new(new_rest_of_proof), cache, enode_cache) } @@ -1320,7 +1449,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { node_explanation_cache: &mut NodeExplanationCache, 
use_unoptimized: bool, ) -> TreeExplanation { - let mut proof = vec![self.node_to_explanation(left, node_explanation_cache)]; + let mut proof = vec![self.expr_reconstruction(left, node_explanation_cache)]; let (left_connections, right_connections) = if use_unoptimized { self.get_path_unoptimized(left, right) } else { @@ -1364,7 +1493,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { let term = match connection.justification { Justification::Rule(name) => { let mut rewritten = - (*self.node_to_explanation(connection.next, node_explanation_cache)).clone(); + (*self.expr_reconstruction(connection.next, node_explanation_cache)).clone(); if connection.is_rewrite_forward { rewritten.forward_rule = Some(name); } else { @@ -1946,7 +2075,10 @@ impl<'x, L: Language> ExplainNodes<'x, L> { #[cfg(test)] mod tests { + use env_logger::try_init; use super::super::*; + use itertools::*; + use log::LevelFilter::Trace; #[test] fn simple_explain() { @@ -2065,7 +2197,6 @@ mod tests { .collect(); let mut egraph = Runner::default() .with_explanations_enabled() - .without_explanation_length_optimization() .with_expr(&"a".parse().unwrap()) .run(&rws) .egraph; @@ -2082,6 +2213,97 @@ mod tests { exp.make_flat_explanation() ); } + + #[test] + fn long_explain_exists() { + //! Same as previous test, but now I want to make a rewrite add some term and see it exists in + //! more then one step + use crate::SymbolLang; + init_logger(); + + let rws: Vec> = + [rewrite!("makeb"; "a" => "b"), rewrite!("makec"; "b" => "c"), rewrite!("revmakec"; "d" => "e"), multi_rewrite!("makef"; "?y = e, ?x = c" => "?y = f")] + .iter() + .cloned() + .collect(); + let mut egraph = Runner::default() + .with_explanations_enabled() + .with_expr(&"a".parse().unwrap()) + .with_expr(&"d".parse().unwrap()) + .run(&rws) + .egraph; + egraph.rebuild(); + let mut exp = egraph.explain_existance(&"f".parse().unwrap()); + println!("{:?}", exp.make_flat_explanation()); + assert_eq!( + exp.make_flat_explanation().len(), + 4, + "Expected 4 steps, got {:?}", + exp.make_flat_explanation() + ); + } + #[test] + fn explain_existance_advanced() { + env_logger::builder().filter_level(Trace).build(); + let mut egraph = EGraph::::default().with_explanations_enabled(); + // 2024-06-25T21:52:39.527646+03:00 - Adding precondition: (cbugs_2eBug01_2eEytanNat___24t cbugs_2eBug01_2eEYTANO) + // 2024-06-25T21:52:39.691301+03:00 - Adding axiom $_typeof_bugs.Bug01.EYTANS as const: (c_24__type__441204___24t cbugs_2eBug01_2eEYTANS) + let nat0 = egraph.add_expr(&"(cbugs_2eBug01_2eEytanNat___24t cbugs_2eBug01_2eEYTANO)".parse().unwrap()); + let ts = egraph.add_expr(&"(c_24__type__441204___24t cbugs_2eBug01_2eEYTANS)".parse().unwrap()); + + // 2024-06-25T21:52:39.692953+03:00 - Creating rewrite for rev_$_type_441204: ?specialty_encoding_x3 = (cbugs_2eBug01_2eEytanNat___24t ?Vv__CANONICAL__0), ?specialty_encoding_x4 = (cbugs_2eBug01_2eEytanNat___24t (happ ?Vvar__441205 ?Vv__CANONICAL__0)) => ?new = (c_24__type__441204___24t ?Vvar__441205) + // 2024-06-25T21:52:39.692708+03:00 - Creating rewrite for $_type_441204: ?specialty_encoding_x0 = (cbugs_2eBug01_2eEytanNat___24t ?Vv__CANONICAL__0), ?specialty_encoding_x1 = (c_24__type__441204___24t ?Vvar__441205) => ?new = (cbugs_2eBug01_2eEytanNat___24t (happ ?Vvar__441205 ?Vv__CANONICAL__0)) + // 2024-06-25T21:52:39.692276+03:00 - Creating rewrite for rev_$tdef_$_type_441204: ?specialty_encoding_x2 = (t ?V_24Y c_24__type__441204) => ?new = (c_24__type__441204___24t ?V_24Y) + // 2024-06-25T21:52:39.692123+03:00 - Creating rewrite for 
$tdef_$_type_441204: ?specialty_encoding_x0 = (c_24__type__441204___24t ?V_24Y) => ?new = (t ?V_24Y c_24__type__441204) + // 2024-06-25T21:52:39.691733+03:00 - Creating eq rewrite for (lhs => rhs) $adef_bugs.Bug01.EYTANS_$a1: all lhs := ?specialty_encoding_eq = (happ cbugs_2eBug01_2eEYTANS ?V_24X1) => rhs := ?specialty_encoding_eq = (cbugs_2eBug01_2eEYTANS___24a1 ?V_24X1) + // 2024-06-25T21:52:39.691780+03:00 - (No need) Creating eq rewrite for (rhs => lhs) rev_$adef_bugs.Bug01.EYTANS_$a1: all conds := ?specialty_encoding_eq = (cbugs_2eBug01_2eEYTANS___24a1 ?V_24X1) => lhs := ?specialty_encoding_eq = (happ cbugs_2eBug01_2eEYTANS ?V_24X1) + // 2024-06-25T21:52:39.691160+03:00 - Creating eq rewrite for $_inj_bugs.Bug01.EYTANS (both sides): ?specialty_encoding_x0 = (cbugs_2eBug01_2eEYTANS___24a1 ?Vvar__0___24Anonymous__441202), ?specialty_encoding_x0 = (cbugs_2eBug01_2eEYTANS___24a1 ?Vvar__0___24Anonymous__441203) => lhs := ?specialty_encoding_eq = ?Vvar__0___24Anonymous__441202, rhs := ?specialty_encoding_eq = ?Vvar__0___24Anonymous__441203 + let rewrites: Vec> = vec![ + multi_rewrite!("rev_$_type_441204"; "?specialty_encoding_x3 = (cbugs_2eBug01_2eEytanNat___24t ?Vv__CANONICAL__0), ?specialty_encoding_x4 = (cbugs_2eBug01_2eEytanNat___24t (happ ?Vvar__441205 ?Vv__CANONICAL__0))" => "?new = (c_24__type__441204___24t ?Vvar__441205)"), + multi_rewrite!("$_type_441204"; "?specialty_encoding_x0 = (cbugs_2eBug01_2eEytanNat___24t ?Vv__CANONICAL__0), ?specialty_encoding_x1 = (c_24__type__441204___24t ?Vvar__441205)" => "?new = (cbugs_2eBug01_2eEytanNat___24t (happ ?Vvar__441205 ?Vv__CANONICAL__0))"), + multi_rewrite!("rev_$tdef_$_type_441204"; "?specialty_encoding_x2 = (t ?V_24Y c_24__type__441204)" => "?new = (c_24__type__441204___24t ?V_24Y)"), + multi_rewrite!("$tdef_$_type_441204"; "?specialty_encoding_x0 = (c_24__type__441204___24t ?V_24Y)" => "?new = (t ?V_24Y c_24__type__441204)"), + multi_rewrite!("$adef_bugs.Bug01.EYTANS_$a1"; "?specialty_encoding_eq = (happ cbugs_2eBug01_2eEYTANS ?V_24X1)" => "?specialty_encoding_eq = (cbugs_2eBug01_2eEYTANS___24a1 ?V_24X1)"), + multi_rewrite!("$_inj_bugs.Bug01.EYTANS"; "?specialty_encoding_x0 = (cbugs_2eBug01_2eEYTANS___24a1 ?Vvar__0___24Anonymous__441202), ?specialty_encoding_x0 = (cbugs_2eBug01_2eEYTANS___24a1 ?Vvar__0___24Anonymous__441203)" => "?specialty_encoding_eq = ?Vvar__0___24Anonymous__441202, ?specialty_encoding_eq = ?Vvar__0___24Anonymous__441203"), + ]; + + let mut egraph = Runner::default() + .with_explanations_enabled() + .with_egraph(egraph) + .run(&rewrites) + .egraph; + let nat1_exp: RecExpr = "(cbugs_2eBug01_2eEytanNat___24t (cbugs_2eBug01_2eEYTANS___24a1 cbugs_2eBug01_2eEYTANO))".parse().unwrap(); + let nat1_lookup = egraph.lookup_expr(&nat1_exp).unwrap(); + println!("Found nat1 @ {}", nat1_lookup); + // println!("Adding nat 1"); + // let nat1 = egraph.add_expr_uncanonical(&nat1_exp); + println!("Adding nat hap"); + let nathap = egraph.add_expr_uncanonical(&"(cbugs_2eBug01_2eEytanNat___24t (happ cbugs_2eBug01_2eEYTANS cbugs_2eBug01_2eEYTANO))".parse().unwrap()); + println!("Adding s0"); + let s0id = egraph.add_expr_uncanonical(&"(ßcbugs_2eBug01_2eEYTANS___24a1 cbugs_2eBug01_2eEYTANO)".parse().unwrap()); + println!("Adding hap s"); + let happsid = egraph.add_expr_uncanonical(&"(happ cbugs_2eBug01_2eEYTANS, cbugs_2eBug01_2eEYTANO)".parse().unwrap()); + let mut exp = egraph.explain_existance(&"(cbugs_2eBug01_2eEytanNat___24t (cbugs_2eBug01_2eEYTANS___24a1 cbugs_2eBug01_2eEYTANO))".parse().unwrap()); + exp.make_flat_explanation(); + // 
Added nat 0 + // added type s + // $_type_441204 (now we have nat (hap s 0)) + // $adef_bugs.Bug01.EYTANS_$a1 goal + println!("{}", exp.make_flat_explanation().iter() + .rev() + .filter_map(|x| { + x.backward_rule + .or(x.forward_rule) + .map(|s| format!("'{}'", s)) + }) + .join("\n")); + assert_eq!( + exp.make_flat_explanation().len(), + 3, + "Expected 3 steps, got \n{}", + exp.make_flat_explanation().iter().join("\n") + ); + } } #[test] @@ -2102,5 +2324,6 @@ fn simple_explain_union_trusted() { egraph.union_trusted(d, fb, "d=fb"); egraph.rebuild(); let mut exp = egraph.explain_equivalence(&"c".parse().unwrap(), &"d".parse().unwrap()); + println!("{:?}", exp.make_flat_explanation()); assert_eq!(exp.make_flat_explanation().len(), 4) } diff --git a/src/language.rs b/src/language.rs index 96890e48..4f0378d0 100644 --- a/src/language.rs +++ b/src/language.rs @@ -28,7 +28,7 @@ use thiserror::Error; /// /// See [`SymbolLang`] for quick-and-dirty use cases. #[allow(clippy::len_without_is_empty)] -pub trait Language: Debug + Clone + Eq + Ord + Hash { +pub trait Language: Debug + Clone + Eq + Ord + Hash + Display { /// Type representing the cases of this language. /// /// Used for short-circuiting the search for equivalent nodes. diff --git a/src/multipattern.rs b/src/multipattern.rs index 7b6d2348..4f805125 100644 --- a/src/multipattern.rs +++ b/src/multipattern.rs @@ -2,6 +2,7 @@ use std::str::FromStr; use thiserror::Error; use crate::*; +use crate::pattern::apply_pat; /// A set of open expressions bound to variables. /// @@ -173,12 +174,23 @@ impl> Applier for MultiPattern { let mut id_buf = vec![]; for (i, (v, p)) in self.asts.iter().enumerate() { id_buf.resize(p.as_ref().len(), 0.into()); - let id1 = crate::pattern::apply_pat(&mut id_buf, p.as_ref(), egraph, &subst); - if let Some(id2) = subst.insert(*v, id1) { - egraph.union(id1, id2, Some(rule_name)); - } - if i == 0 { - added.push(id1) + if egraph.are_explanations_enabled() { + // If we have a union to do we can use union instantiations, otherwise + // we just need to add uncanonical with reason (maybe apply_pat?) + let (id, did_something) = { + egraph.union_instantiations(p, &self.ast, subst, rule_name) + }; + if did_something { + added.push(id) + } + } else { + let id1 = crate::pattern::apply_pat(&mut id_buf, p.as_ref(), egraph, &subst); + if let Some(id2) = subst.insert(*v, id1) { + egraph.union(id1, id2, Some(rule_name)); + } + if i == 0 { + added.push(id1) + } } } } diff --git a/src/run.rs b/src/run.rs index a7762fee..8cb6643b 100644 --- a/src/run.rs +++ b/src/run.rs @@ -577,6 +577,7 @@ where let rebuild_time = Instant::now(); let n_rebuilds = self.egraph.rebuild(); + #[cfg(feature = "check_proof")] if self.egraph.are_explanations_enabled() { debug_assert!(self.egraph.check_each_explain(rules)); } diff --git a/src/test.rs b/src/test.rs index 4784e816..9308d1fa 100644 --- a/src/test.rs +++ b/src/test.rs @@ -102,6 +102,7 @@ pub fn test_runner( explained.get_string_with_let(); let flattened = explained.make_flat_explanation().clone(); let vanilla_len = flattened.len(); + #[cfg(feature = "check_proof")] explained.check_proof(rules); assert!(!explained.get_tree_size().is_zero()); @@ -111,6 +112,7 @@ pub fn test_runner( let short_len = explained_short.get_flat_strings().len(); assert!(short_len <= vanilla_len); assert!(!explained_short.get_tree_size().is_zero()); + #[cfg(feature = "check_proof")] explained_short.check_proof(rules); } }
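
Usage sketch (reviewer note, not part of the patch): with this change, a MultiPattern applier run with explanations enabled routes each instantiated right-hand-side binding through union_instantiations, so a term introduced by a multi_rewrite! carries an existence reason and can be queried with explain_existance. The Rust snippet below mirrors the long_explain_exists test above (rule names are taken from that test; the crate is assumed to be egg); re-checking the resulting proofs against the rules additionally requires the new, off-by-default check_proof cargo feature added in Cargo.toml and wired into src/run.rs and src/test.rs.

    use egg::{multi_rewrite, rewrite, Rewrite, Runner, SymbolLang};

    fn main() {
        // Plain rewrites grow "a -> b -> c" and "d -> e"; the multipattern then
        // needs both "e" and "c" to exist before it introduces "f".
        let rules: Vec<Rewrite<SymbolLang, ()>> = vec![
            rewrite!("makeb"; "a" => "b"),
            rewrite!("makec"; "b" => "c"),
            rewrite!("revmakec"; "d" => "e"),
            multi_rewrite!("makef"; "?y = e, ?x = c" => "?y = f"),
        ];

        let mut egraph = Runner::default()
            .with_explanations_enabled()
            .with_expr(&"a".parse().unwrap())
            .with_expr(&"d".parse().unwrap())
            .run(&rules)
            .egraph;
        egraph.rebuild();

        // Ask why "f" exists at all; with this patch the multipattern applier
        // records a reason for the instantiation instead of silently adding it.
        let mut existence = egraph.explain_existance(&"f".parse().unwrap());
        println!("{}", existence.get_flat_string());
    }

Since the commit message marks this as work in progress, the exact shape of the existence proof (its length and which reasons are attached) may still change as apply_matches is updated.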