diff --git a/src/egraph.rs b/src/egraph.rs index b8688153..4fad79de 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -231,8 +231,8 @@ impl> EGraph { let right_unions = other.get_union_equalities(); for (left, right, why) in right_unions { self.union_instantiations( - &other.id_to_pattern(left, &Default::default()).0.ast, - &other.id_to_pattern(right, &Default::default()).0.ast, + &other.id_to_pattern(left, |_| None).0.ast, + &other.id_to_pattern(right, |_| None).0.ast, &Default::default(), why, ); @@ -376,11 +376,15 @@ impl> EGraph { /// When an eclass listed in the given substitutions is found, it creates a variable. /// It also adds this variable and the corresponding Id value to the resulting [`Subst`] /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr). - pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap) -> (Pattern, Subst) { + pub fn id_to_pattern( + &self, + id: Id, + mut substitutions: impl FnMut(Id) -> Option, + ) -> (Pattern, Subst) { let mut res = Default::default(); let mut subst = Default::default(); let mut cache = Default::default(); - self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache); + self.id_to_pattern_internal(&mut res, id, &mut substitutions, &mut subst, &mut cache); (Pattern::new(res), subst) } @@ -388,16 +392,16 @@ impl> EGraph { &self, res: &mut PatternAst, node_id: Id, - var_substitutions: &HashMap, + var_substitutions: &mut impl FnMut(Id) -> Option, subst: &mut Subst, cache: &mut HashMap, ) -> Id { if let Some(existing) = cache.get(&node_id) { return *existing; } - let res_id = if let Some(existing) = var_substitutions.get(&node_id) { + let res_id = if let Some(existing) = var_substitutions(node_id) { let var = format!("?{}", node_id).parse().unwrap(); - subst.insert(var, *existing); + subst.insert(var, existing); res.add(ENodeOrVar::Var(var)) } else { let new_node = self.id_to_node(node_id).clone().map_children(|child| {