diff --git a/src/egraph.rs b/src/egraph.rs index c832901e..1677c077 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -862,10 +862,9 @@ impl> EGraph { /// /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled pub fn add_expr_uncanonical(&mut self, expr: &RecExpr) -> Id { - let nodes = expr.as_ref(); - let mut new_ids = Vec::with_capacity(nodes.len()); - let mut new_node_q = Vec::with_capacity(nodes.len()); - for node in nodes { + let mut new_ids = Vec::with_capacity(expr.len()); + let mut new_node_q = Vec::with_capacity(expr.len()); + for node in expr { let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]); let size_before = self.unionfind.size(); let next_id = self.add_uncanonical(new_node); @@ -901,10 +900,9 @@ impl> EGraph { /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the /// instantiation of the pattern fn add_instantiation_noncanonical(&mut self, pat: &PatternAst, subst: &Subst) -> Id { - let nodes = pat.as_ref(); - let mut new_ids = Vec::with_capacity(nodes.len()); - let mut new_node_q = Vec::with_capacity(nodes.len()); - for node in nodes { + let mut new_ids = Vec::with_capacity(pat.len()); + let mut new_node_q = Vec::with_capacity(pat.len()); + for node in pat { match node { ENodeOrVar::Var(var) => { let id = self.find(subst[*var]); @@ -986,9 +984,8 @@ impl> EGraph { /// Lookup the eclasses of all the nodes in the given [`RecExpr`]. pub fn lookup_expr_ids(&self, expr: &RecExpr) -> Option> { - let nodes = expr.as_ref(); - let mut new_ids = Vec::with_capacity(nodes.len()); - for node in nodes { + let mut new_ids = Vec::with_capacity(expr.len()); + for node in expr { let node = node.clone().map_children(|i| new_ids[usize::from(i)]); let id = self.lookup(node)?; new_ids.push(id) @@ -1118,8 +1115,8 @@ impl> EGraph { /// In most cases, there will none or exactly one id. /// pub fn equivs(&self, expr1: &RecExpr, expr2: &RecExpr) -> Vec { - let pat1 = Pattern::from(expr1.as_ref()); - let pat2 = Pattern::from(expr2.as_ref()); + let pat1 = Pattern::from(expr1); + let pat2 = Pattern::from(expr2); let matches1 = pat1.search(self); trace!("Matches1: {:?}", matches1); diff --git a/src/explain.rs b/src/explain.rs index ab2fca52..cbc89f10 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -792,11 +792,9 @@ impl FlatTerm { /// Rewrite the FlatTerm by matching the lhs and substituting the rhs. /// The lhs must be guaranteed to match. pub fn rewrite(&self, lhs: &PatternAst, rhs: &PatternAst) -> FlatTerm { - let lhs_nodes = lhs.as_ref(); - let rhs_nodes = rhs.as_ref(); let mut bindings = Default::default(); - self.make_bindings(lhs_nodes, lhs_nodes.len() - 1, &mut bindings); - FlatTerm::from_pattern(rhs_nodes, rhs_nodes.len() - 1, &bindings) + self.make_bindings(lhs, lhs.len() - 1, &mut bindings); + FlatTerm::from_pattern(rhs, rhs.len() - 1, &bindings) } /// Checks if this term or any child has a [`forward_rule`](FlatTerm::forward_rule). diff --git a/src/extract.rs b/src/extract.rs index b2ca9b16..762359ee 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -134,14 +134,13 @@ pub trait CostFunction { /// down the [`RecExpr`]. /// fn cost_rec(&mut self, expr: &RecExpr) -> Self::Cost { - let nodes = expr.as_ref(); - let mut costs = hashmap_with_capacity::(nodes.len()); - for (i, node) in nodes.iter().enumerate() { + let mut costs = hashmap_with_capacity::(expr.len()); + for (i, node) in expr.items() { let cost = self.cost(node, |i| costs[&i].clone()); - costs.insert(Id::from(i), cost); + costs.insert(i, cost); } - let last_id = Id::from(expr.as_ref().len() - 1); - costs[&last_id].clone() + let root = expr.root(); + costs[&root].clone() } } diff --git a/src/language.rs b/src/language.rs index 7a2f9dbc..6ccecefd 100644 --- a/src/language.rs +++ b/src/language.rs @@ -1,4 +1,6 @@ -use std::ops::{BitOr, Index, IndexMut}; +use std::borrow::{Borrow, BorrowMut}; +use std::iter::FromIterator; +use std::ops::{BitOr, Deref, DerefMut, Index, IndexMut}; use std::{cmp::Ordering, convert::TryFrom}; use std::{ convert::Infallible, @@ -213,10 +215,10 @@ pub trait Language: Debug + Clone + Eq + Ord + Hash { } } - // finally, add the root node and create the expression - let mut nodes: Vec = set.into_iter().collect(); - nodes.push(self.clone().map_children(|id| ids[&id])); - Ok(RecExpr::from(nodes)) + // finally, create the expression and add the root node + let mut expr: RecExpr<_> = set.into_iter().collect(); + expr.add(self.clone().map_children(|id| ids[&id])); + Ok(expr) } } @@ -398,12 +400,44 @@ impl Default for RecExpr { } } +impl Borrow<[L]> for RecExpr { + fn borrow(&self) -> &[L] { + &self.nodes + } +} + +impl BorrowMut<[L]> for RecExpr { + fn borrow_mut(&mut self) -> &mut [L] { + &mut self.nodes + } +} + +impl Deref for RecExpr { + type Target = [L]; + + fn deref(&self) -> &Self::Target { + self.borrow() + } +} + +impl DerefMut for RecExpr { + fn deref_mut(&mut self) -> &mut Self::Target { + self.borrow_mut() + } +} + impl AsRef<[L]> for RecExpr { fn as_ref(&self) -> &[L] { &self.nodes } } +impl AsMut<[L]> for RecExpr { + fn as_mut(&mut self) -> &mut [L] { + &mut self.nodes + } +} + impl From> for RecExpr { fn from(nodes: Vec) -> Self { Self { nodes } @@ -416,22 +450,28 @@ impl From> for Vec { } } +impl FromIterator for RecExpr { + fn from_iter>(iter: T) -> Self { + Self::from(iter.into_iter().collect::>()) + } +} + impl RecExpr { /// Adds a given enode to this `RecExpr`. - /// The enode's children `Id`s must refer to elements already in this list. + /// The enode's children [`Id`]s must refer to elements already in this list. pub fn add(&mut self, node: L) -> Id { debug_assert!( - node.all(|id| usize::from(id) < self.nodes.len()), + node.all(|id| id <= self.root()), "node {:?} has children not in this expr: {:?}", node, self ); self.nodes.push(node); - Id::from(self.nodes.len() - 1) + self.root() } pub(crate) fn compact(mut self) -> Self { - let mut ids = hashmap_with_capacity::(self.nodes.len()); + let mut ids = hashmap_with_capacity::(self.len()); let mut set = IndexSet::default(); for (i, node) in self.nodes.drain(..).enumerate() { let node = node.map_children(|id| ids[&id]); @@ -446,21 +486,31 @@ impl RecExpr { self[new_root].build_recexpr(|id| self[id].clone()) } + /// Returns an iterator over the [`Id`]s in this expression. + pub fn ids(&self) -> impl ExactSizeIterator + DoubleEndedIterator { + (0..self.len()).map(Id::from) + } + + /// Returns an iterator over the [`Id`]s and enodes of this expression. + pub fn items(&self) -> impl ExactSizeIterator + DoubleEndedIterator { + self.ids().zip(self) + } + + /// Returns an iterator over the [`Id`]s and enodes of this expression. + pub fn items_mut( + &mut self, + ) -> impl ExactSizeIterator + DoubleEndedIterator { + self.ids().zip(self) + } + /// Checks if this expr is a DAG, i.e. doesn't have any back edges pub fn is_dag(&self) -> bool { - for (i, n) in self.nodes.iter().enumerate() { - for &child in n.children() { - if usize::from(child) >= i { - return false; - } - } - } - true + self.items().all(|(id, n)| n.all(|child| child < id)) } /// Get the root node of this expression. When adding a new node via `add`, it becomes the new root. pub fn root(&self) -> Id { - Id::from(self.nodes.len() - 1) + self.ids().last().unwrap() } } @@ -477,6 +527,33 @@ impl IndexMut for RecExpr { } } +impl IntoIterator for RecExpr { + type Item = L; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.nodes.into_iter() + } +} + +impl<'a, L> IntoIterator for &'a RecExpr { + type Item = &'a L; + type IntoIter = std::slice::Iter<'a, L>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, L> IntoIterator for &'a mut RecExpr { + type Item = &'a mut L; + type IntoIter = std::slice::IterMut<'a, L>; + + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } +} + impl Display for RecExpr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.nodes.is_empty() { diff --git a/src/lp_extract.rs b/src/lp_extract.rs index 5a4057be..7424e9a5 100644 --- a/src/lp_extract.rs +++ b/src/lp_extract.rs @@ -52,7 +52,7 @@ impl> LpCostFunction for AstSize { /// // Using ILP only counts common sub-expressions once, /// // so it can lead to a smaller DAG expression. /// assert_eq!(lp_best.to_string(), "(f x x x)"); -/// assert_eq!(lp_best.as_ref().len(), 2); +/// assert_eq!(lp_best.len(), 2); /// ``` #[cfg_attr(docsrs, doc(cfg(feature = "lp")))] pub struct LpExtractor<'a, L: Language, N: Analysis> { @@ -260,7 +260,7 @@ mod tests { let (exp, ids) = ext.solve_multiple(&[f, g]); println!("{:?}", exp); println!("{}", exp); - assert_eq!(exp.as_ref().len(), 4); + assert_eq!(exp.len(), 4); assert_eq!(ids.len(), 2); } } diff --git a/src/machine.rs b/src/machine.rs index 54f55a3f..65f23ff8 100644 --- a/src/machine.rs +++ b/src/machine.rs @@ -142,11 +142,11 @@ impl Compiler { } fn load_pattern(&mut self, pattern: &PatternAst) { - let len = pattern.as_ref().len(); + let len = pattern.len(); self.free_vars = Vec::with_capacity(len); self.subtree_size = Vec::with_capacity(len); - for node in pattern.as_ref() { + for node in pattern { let mut free = HashSet::default(); let mut size = 0; match node { @@ -199,7 +199,7 @@ impl Compiler { fn compile(&mut self, patternbinder: Option, pattern: &PatternAst) { self.load_pattern(pattern); - let last_i = pattern.as_ref().len() - 1; + let root = pattern.root(); let mut next_out = self.next_reg; @@ -211,13 +211,13 @@ impl Compiler { comp.instructions .push(Instruction::Scan { out: comp.next_reg }); } - comp.add_todo(pattern, Id::from(last_i), comp.next_reg); + comp.add_todo(pattern, root, comp.next_reg); }; if let Some(v) = patternbinder { if let Some(&i) = self.v2r.get(&v) { // patternbinder already bound - self.add_todo(pattern, Id::from(last_i), i); + self.add_todo(pattern, root, i); } else { // patternbinder is new variable next_out.0 += 1; @@ -236,7 +236,6 @@ impl Compiler { self.instructions.push(Instruction::Lookup { i: reg, term: extracted - .as_ref() .iter() .map(|n| match n { ENodeOrVar::ENode(n) => ENodeOrReg::ENode(n.clone()), diff --git a/src/multipattern.rs b/src/multipattern.rs index 4a13a170..c1ad9ed4 100644 --- a/src/multipattern.rs +++ b/src/multipattern.rs @@ -110,7 +110,7 @@ impl> Searcher for MultiPattern { match self.asts.as_slice() { [] => panic!("empty multipattern"), [(_var, pat), ..] => { - if let [ENodeOrVar::Var(_)] = pat.as_ref() { + if let [ENodeOrVar::Var(_)] = **pat { panic!( "Bare cannot be first pattern variable in multipattern: {:?}", self.asts @@ -134,7 +134,7 @@ impl> Searcher for MultiPattern { let mut vars = vec![]; for (v, pat) in &self.asts { vars.push(*v); - for n in pat.as_ref() { + for n in pat { if let ENodeOrVar::Var(v) = n { vars.push(*v) } @@ -172,8 +172,8 @@ impl> Applier for MultiPattern { let mut subst = subst.clone(); let mut id_buf = vec![]; for (i, (v, p)) in self.asts.iter().enumerate() { - id_buf.resize(p.as_ref().len(), 0.into()); - let id1 = crate::pattern::apply_pat(&mut id_buf, p.as_ref(), egraph, &subst); + id_buf.resize(p.len(), 0.into()); + let id1 = crate::pattern::apply_pat(&mut id_buf, p, egraph, &subst); if let Some(id2) = subst.insert(*v, id1) { egraph.union(id1, id2); } @@ -190,7 +190,7 @@ impl> Applier for MultiPattern { let mut bound_vars = HashSet::default(); let mut vars = vec![]; for (bv, pat) in &self.asts { - for n in pat.as_ref() { + for n in pat { if let ENodeOrVar::Var(v) = n { // using vars that are already bound doesn't count if !bound_vars.contains(v) { diff --git a/src/pattern.rs b/src/pattern.rs index 283c7878..047108a0 100644 --- a/src/pattern.rs +++ b/src/pattern.rs @@ -1,6 +1,7 @@ use fmt::Formatter; use log::*; use std::borrow::Cow; +use std::convert::TryInto; use std::fmt::{self, Display}; use std::{convert::TryFrom, str::FromStr}; @@ -86,7 +87,7 @@ impl PatternAst { } } - for n in self.as_ref() { + for n in self { new.add(match n { ENodeOrVar::ENode(_) => n.clone(), ENodeOrVar::Var(v) => { @@ -111,7 +112,7 @@ impl Pattern { /// Returns a list of the [`Var`]s in this pattern. pub fn vars(&self) -> Vec { let mut vars = vec![]; - for n in self.ast.as_ref() { + for n in &self.ast { if let ENodeOrVar::Var(v) = n { if !vars.contains(v) { vars.push(*v) @@ -225,8 +226,14 @@ impl std::str::FromStr for Pattern { impl<'a, L: Language> From<&'a [L]> for Pattern { fn from(expr: &'a [L]) -> Self { - let nodes: Vec<_> = expr.iter().cloned().map(ENodeOrVar::ENode).collect(); - let ast = RecExpr::from(nodes); + let ast = expr.iter().cloned().map(ENodeOrVar::ENode).collect(); + Self::new(ast) + } +} + +impl From> for Pattern { + fn from(expr: RecExpr) -> Self { + let ast = expr.into_iter().map(ENodeOrVar::ENode).collect(); Self::new(ast) } } @@ -243,17 +250,22 @@ impl From> for Pattern { } } -impl TryFrom> for RecExpr { +impl TryFrom> for RecExpr { type Error = Var; - fn try_from(pat: Pattern) -> Result { - let nodes = pat.ast.as_ref().iter().cloned(); - let ns: Result, _> = nodes + fn try_from(ast: PatternAst) -> Result { + ast.into_iter() .map(|n| match n { ENodeOrVar::ENode(n) => Ok(n), ENodeOrVar::Var(v) => Err(v), }) - .collect(); - ns.map(RecExpr::from) + .collect() + } +} + +impl TryFrom> for RecExpr { + type Error = Var; + fn try_from(pat: Pattern) -> Result { + pat.ast.try_into() } } @@ -286,7 +298,7 @@ impl> Searcher for Pattern { } fn search_with_limit(&self, egraph: &EGraph, limit: usize) -> Vec> { - match self.ast.as_ref().last().unwrap() { + match self.ast.last().unwrap() { ENodeOrVar::ENode(e) => { let key = e.discriminant(); match egraph.classes_for_op(&key) { @@ -343,8 +355,7 @@ where rule_name: Symbol, ) -> Vec { let mut added = vec![]; - let ast = self.ast.as_ref(); - let mut id_buf = vec![0.into(); ast.len()]; + let mut id_buf = vec![0.into(); self.ast.len()]; for mat in matches { let sast = mat.ast.as_ref().map(|cow| cow.as_ref()); for subst in &mat.substs { @@ -356,7 +367,7 @@ where did_something = did_something_temp; id = id_temp; } else { - id = apply_pat(&mut id_buf, ast, egraph, subst); + id = apply_pat(&mut id_buf, &self.ast, egraph, subst); did_something = egraph.union(id, mat.eclass); } @@ -376,9 +387,8 @@ where searcher_ast: Option<&PatternAst>, rule_name: Symbol, ) -> Vec { - let ast = self.ast.as_ref(); - let mut id_buf = vec![0.into(); ast.len()]; - let id = apply_pat(&mut id_buf, ast, egraph, subst); + let mut id_buf = vec![0.into(); self.ast.len()]; + let id = apply_pat(&mut id_buf, &self.ast, egraph, subst); if let Some(ast) = searcher_ast { let (from, did_something) = diff --git a/src/rewrite.rs b/src/rewrite.rs index 1a3ea5f4..0dfae582 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -524,10 +524,10 @@ where N: Analysis, { fn check(&self, egraph: &mut EGraph, _eclass: Id, subst: &Subst) -> bool { - let mut id_buf_1 = vec![0.into(); self.p1.ast.as_ref().len()]; - let mut id_buf_2 = vec![0.into(); self.p2.ast.as_ref().len()]; - let a1 = apply_pat(&mut id_buf_1, self.p1.ast.as_ref(), egraph, subst); - let a2 = apply_pat(&mut id_buf_2, self.p2.ast.as_ref(), egraph, subst); + let mut id_buf_1 = vec![0.into(); self.p1.ast.len()]; + let mut id_buf_2 = vec![0.into(); self.p2.ast.len()]; + let a1 = apply_pat(&mut id_buf_1, &self.p1.ast, egraph, subst); + let a2 = apply_pat(&mut id_buf_2, &self.p2.ast, egraph, subst); a1 == a2 } diff --git a/src/test.rs b/src/test.rs index 4784e816..8565f79a 100644 --- a/src/test.rs +++ b/src/test.rs @@ -155,10 +155,10 @@ where eprintln!("{} patterns", patterns.len()); - patterns.retain(|p| p.ast.as_ref().len() > 1); + patterns.retain(|p| p.ast.len() > 1); patterns.sort_by_key(|p| p.to_string()); patterns.dedup(); - patterns.sort_by_key(|p| p.ast.as_ref().len()); + patterns.sort_by_key(|p| p.ast.len()); let iter_limit = env_var("EGG_ITER_LIMIT").unwrap_or(1); let node_limit = env_var("EGG_NODE_LIMIT").unwrap_or(1_000_000); diff --git a/tests/math.rs b/tests/math.rs index a0d8c07a..89fa80fb 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -350,12 +350,12 @@ fn math_lp_extract() { let best = Extractor::new(&runner.egraph, AstSize).find_best(root).1; let lp_best = LpExtractor::new(&runner.egraph, AstSize).solve(root); - println!("input [{}] {}", expr.as_ref().len(), expr); - println!("normal [{}] {}", best.as_ref().len(), best); - println!("ilp cse [{}] {}", lp_best.as_ref().len(), lp_best); + println!("input [{}] {}", expr.len(), expr); + println!("normal [{}] {}", best.len(), best); + println!("ilp cse [{}] {}", lp_best.len(), lp_best); assert_ne!(best, lp_best); - assert_eq!(lp_best.as_ref().len(), 4); + assert_eq!(lp_best.len(), 4); } #[test]