Skip to content

Commit

Permalink
Add deref, borrow, and iter methods of slice to RecExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
mtak- authored Dec 30, 2024
1 parent 319b9c4 commit fa40e29
Show file tree
Hide file tree
Showing 11 changed files with 161 additions and 81 deletions.
23 changes: 10 additions & 13 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,10 +862,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
///
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled
pub fn add_expr_uncanonical(&mut self, expr: &RecExpr<L>) -> Id {
let nodes = expr.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
let mut new_node_q = Vec::with_capacity(nodes.len());
for node in nodes {
let mut new_ids = Vec::with_capacity(expr.len());
let mut new_node_q = Vec::with_capacity(expr.len());
for node in expr {
let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
let size_before = self.unionfind.size();
let next_id = self.add_uncanonical(new_node);
Expand Down Expand Up @@ -901,10 +900,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the
/// instantiation of the pattern
fn add_instantiation_noncanonical(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
let nodes = pat.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
let mut new_node_q = Vec::with_capacity(nodes.len());
for node in nodes {
let mut new_ids = Vec::with_capacity(pat.len());
let mut new_node_q = Vec::with_capacity(pat.len());
for node in pat {
match node {
ENodeOrVar::Var(var) => {
let id = self.find(subst[*var]);
Expand Down Expand Up @@ -986,9 +984,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

/// Lookup the eclasses of all the nodes in the given [`RecExpr`].
pub fn lookup_expr_ids(&self, expr: &RecExpr<L>) -> Option<Vec<Id>> {
let nodes = expr.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
for node in nodes {
let mut new_ids = Vec::with_capacity(expr.len());
for node in expr {
let node = node.clone().map_children(|i| new_ids[usize::from(i)]);
let id = self.lookup(node)?;
new_ids.push(id)
Expand Down Expand Up @@ -1118,8 +1115,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// In most cases, there will none or exactly one id.
///
pub fn equivs(&self, expr1: &RecExpr<L>, expr2: &RecExpr<L>) -> Vec<Id> {
let pat1 = Pattern::from(expr1.as_ref());
let pat2 = Pattern::from(expr2.as_ref());
let pat1 = Pattern::from(expr1);
let pat2 = Pattern::from(expr2);
let matches1 = pat1.search(self);
trace!("Matches1: {:?}", matches1);

Expand Down
6 changes: 2 additions & 4 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,11 +792,9 @@ impl<L: Language> FlatTerm<L> {
/// Rewrite the FlatTerm by matching the lhs and substituting the rhs.
/// The lhs must be guaranteed to match.
pub fn rewrite(&self, lhs: &PatternAst<L>, rhs: &PatternAst<L>) -> FlatTerm<L> {
let lhs_nodes = lhs.as_ref();
let rhs_nodes = rhs.as_ref();
let mut bindings = Default::default();
self.make_bindings(lhs_nodes, lhs_nodes.len() - 1, &mut bindings);
FlatTerm::from_pattern(rhs_nodes, rhs_nodes.len() - 1, &bindings)
self.make_bindings(lhs, lhs.len() - 1, &mut bindings);
FlatTerm::from_pattern(rhs, rhs.len() - 1, &bindings)
}

/// Checks if this term or any child has a [`forward_rule`](FlatTerm::forward_rule).
Expand Down
11 changes: 5 additions & 6 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,13 @@ pub trait CostFunction<L: Language> {
/// down the [`RecExpr`].
///
fn cost_rec(&mut self, expr: &RecExpr<L>) -> Self::Cost {
let nodes = expr.as_ref();
let mut costs = hashmap_with_capacity::<Id, Self::Cost>(nodes.len());
for (i, node) in nodes.iter().enumerate() {
let mut costs = hashmap_with_capacity::<Id, Self::Cost>(expr.len());
for (i, node) in expr.items() {
let cost = self.cost(node, |i| costs[&i].clone());
costs.insert(Id::from(i), cost);
costs.insert(i, cost);
}
let last_id = Id::from(expr.as_ref().len() - 1);
costs[&last_id].clone()
let root = expr.root();
costs[&root].clone()
}
}

Expand Down
113 changes: 95 additions & 18 deletions src/language.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::ops::{BitOr, Index, IndexMut};
use std::borrow::{Borrow, BorrowMut};
use std::iter::FromIterator;
use std::ops::{BitOr, Deref, DerefMut, Index, IndexMut};
use std::{cmp::Ordering, convert::TryFrom};
use std::{
convert::Infallible,
Expand Down Expand Up @@ -213,10 +215,10 @@ pub trait Language: Debug + Clone + Eq + Ord + Hash {
}
}

// finally, add the root node and create the expression
let mut nodes: Vec<Self> = set.into_iter().collect();
nodes.push(self.clone().map_children(|id| ids[&id]));
Ok(RecExpr::from(nodes))
// finally, create the expression and add the root node
let mut expr: RecExpr<_> = set.into_iter().collect();
expr.add(self.clone().map_children(|id| ids[&id]));
Ok(expr)
}
}

Expand Down Expand Up @@ -398,12 +400,44 @@ impl<L> Default for RecExpr<L> {
}
}

impl<L> Borrow<[L]> for RecExpr<L> {
fn borrow(&self) -> &[L] {
&self.nodes
}
}

impl<L> BorrowMut<[L]> for RecExpr<L> {
fn borrow_mut(&mut self) -> &mut [L] {
&mut self.nodes
}
}

impl<L> Deref for RecExpr<L> {
type Target = [L];

fn deref(&self) -> &Self::Target {
self.borrow()
}
}

impl<L> DerefMut for RecExpr<L> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.borrow_mut()
}
}

impl<L> AsRef<[L]> for RecExpr<L> {
fn as_ref(&self) -> &[L] {
&self.nodes
}
}

impl<L> AsMut<[L]> for RecExpr<L> {
fn as_mut(&mut self) -> &mut [L] {
&mut self.nodes
}
}

impl<L> From<Vec<L>> for RecExpr<L> {
fn from(nodes: Vec<L>) -> Self {
Self { nodes }
Expand All @@ -416,22 +450,28 @@ impl<L> From<RecExpr<L>> for Vec<L> {
}
}

impl<L> FromIterator<L> for RecExpr<L> {
fn from_iter<T: IntoIterator<Item = L>>(iter: T) -> Self {
Self::from(iter.into_iter().collect::<Vec<_>>())
}
}

impl<L: Language> RecExpr<L> {
/// Adds a given enode to this `RecExpr`.
/// The enode's children `Id`s must refer to elements already in this list.
/// The enode's children [`Id`]s must refer to elements already in this list.
pub fn add(&mut self, node: L) -> Id {
debug_assert!(
node.all(|id| usize::from(id) < self.nodes.len()),
node.all(|id| id <= self.root()),
"node {:?} has children not in this expr: {:?}",
node,
self
);
self.nodes.push(node);
Id::from(self.nodes.len() - 1)
self.root()
}

pub(crate) fn compact(mut self) -> Self {
let mut ids = hashmap_with_capacity::<Id, Id>(self.nodes.len());
let mut ids = hashmap_with_capacity::<Id, Id>(self.len());
let mut set = IndexSet::default();
for (i, node) in self.nodes.drain(..).enumerate() {
let node = node.map_children(|id| ids[&id]);
Expand All @@ -446,21 +486,31 @@ impl<L: Language> RecExpr<L> {
self[new_root].build_recexpr(|id| self[id].clone())
}

/// Returns an iterator over the [`Id`]s in this expression.
pub fn ids(&self) -> impl ExactSizeIterator<Item = Id> + DoubleEndedIterator {
(0..self.len()).map(Id::from)
}

/// Returns an iterator over the [`Id`]s and enodes of this expression.
pub fn items(&self) -> impl ExactSizeIterator<Item = (Id, &L)> + DoubleEndedIterator {
self.ids().zip(self)
}

/// Returns an iterator over the [`Id`]s and enodes of this expression.
pub fn items_mut(
&mut self,
) -> impl ExactSizeIterator<Item = (Id, &mut L)> + DoubleEndedIterator {
self.ids().zip(self)
}

/// Checks if this expr is a DAG, i.e. doesn't have any back edges
pub fn is_dag(&self) -> bool {
for (i, n) in self.nodes.iter().enumerate() {
for &child in n.children() {
if usize::from(child) >= i {
return false;
}
}
}
true
self.items().all(|(id, n)| n.all(|child| child < id))
}

/// Get the root node of this expression. When adding a new node via `add`, it becomes the new root.
pub fn root(&self) -> Id {
Id::from(self.nodes.len() - 1)
self.ids().last().unwrap()
}
}

Expand All @@ -477,6 +527,33 @@ impl<L: Language> IndexMut<Id> for RecExpr<L> {
}
}

impl<L> IntoIterator for RecExpr<L> {
type Item = L;
type IntoIter = std::vec::IntoIter<L>;

fn into_iter(self) -> Self::IntoIter {
self.nodes.into_iter()
}
}

impl<'a, L> IntoIterator for &'a RecExpr<L> {
type Item = &'a L;
type IntoIter = std::slice::Iter<'a, L>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

impl<'a, L> IntoIterator for &'a mut RecExpr<L> {
type Item = &'a mut L;
type IntoIter = std::slice::IterMut<'a, L>;

fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}

impl<L: Language + Display> Display for RecExpr<L> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.nodes.is_empty() {
Expand Down
4 changes: 2 additions & 2 deletions src/lp_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl<L: Language, N: Analysis<L>> LpCostFunction<L, N> for AstSize {
/// // Using ILP only counts common sub-expressions once,
/// // so it can lead to a smaller DAG expression.
/// assert_eq!(lp_best.to_string(), "(f x x x)");
/// assert_eq!(lp_best.as_ref().len(), 2);
/// assert_eq!(lp_best.len(), 2);
/// ```
#[cfg_attr(docsrs, doc(cfg(feature = "lp")))]
pub struct LpExtractor<'a, L: Language, N: Analysis<L>> {
Expand Down Expand Up @@ -260,7 +260,7 @@ mod tests {
let (exp, ids) = ext.solve_multiple(&[f, g]);
println!("{:?}", exp);
println!("{}", exp);
assert_eq!(exp.as_ref().len(), 4);
assert_eq!(exp.len(), 4);
assert_eq!(ids.len(), 2);
}
}
11 changes: 5 additions & 6 deletions src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@ impl<L: Language> Compiler<L> {
}

fn load_pattern(&mut self, pattern: &PatternAst<L>) {
let len = pattern.as_ref().len();
let len = pattern.len();
self.free_vars = Vec::with_capacity(len);
self.subtree_size = Vec::with_capacity(len);

for node in pattern.as_ref() {
for node in pattern {
let mut free = HashSet::default();
let mut size = 0;
match node {
Expand Down Expand Up @@ -199,7 +199,7 @@ impl<L: Language> Compiler<L> {

fn compile(&mut self, patternbinder: Option<Var>, pattern: &PatternAst<L>) {
self.load_pattern(pattern);
let last_i = pattern.as_ref().len() - 1;
let root = pattern.root();

let mut next_out = self.next_reg;

Expand All @@ -211,13 +211,13 @@ impl<L: Language> Compiler<L> {
comp.instructions
.push(Instruction::Scan { out: comp.next_reg });
}
comp.add_todo(pattern, Id::from(last_i), comp.next_reg);
comp.add_todo(pattern, root, comp.next_reg);
};

if let Some(v) = patternbinder {
if let Some(&i) = self.v2r.get(&v) {
// patternbinder already bound
self.add_todo(pattern, Id::from(last_i), i);
self.add_todo(pattern, root, i);
} else {
// patternbinder is new variable
next_out.0 += 1;
Expand All @@ -236,7 +236,6 @@ impl<L: Language> Compiler<L> {
self.instructions.push(Instruction::Lookup {
i: reg,
term: extracted
.as_ref()
.iter()
.map(|n| match n {
ENodeOrVar::ENode(n) => ENodeOrReg::ENode(n.clone()),
Expand Down
10 changes: 5 additions & 5 deletions src/multipattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl<L: Language, A: Analysis<L>> Searcher<L, A> for MultiPattern<L> {
match self.asts.as_slice() {
[] => panic!("empty multipattern"),
[(_var, pat), ..] => {
if let [ENodeOrVar::Var(_)] = pat.as_ref() {
if let [ENodeOrVar::Var(_)] = **pat {
panic!(
"Bare cannot be first pattern variable in multipattern: {:?}",
self.asts
Expand All @@ -134,7 +134,7 @@ impl<L: Language, A: Analysis<L>> Searcher<L, A> for MultiPattern<L> {
let mut vars = vec![];
for (v, pat) in &self.asts {
vars.push(*v);
for n in pat.as_ref() {
for n in pat {
if let ENodeOrVar::Var(v) = n {
vars.push(*v)
}
Expand Down Expand Up @@ -172,8 +172,8 @@ impl<L: Language, A: Analysis<L>> Applier<L, A> for MultiPattern<L> {
let mut subst = subst.clone();
let mut id_buf = vec![];
for (i, (v, p)) in self.asts.iter().enumerate() {
id_buf.resize(p.as_ref().len(), 0.into());
let id1 = crate::pattern::apply_pat(&mut id_buf, p.as_ref(), egraph, &subst);
id_buf.resize(p.len(), 0.into());
let id1 = crate::pattern::apply_pat(&mut id_buf, p, egraph, &subst);
if let Some(id2) = subst.insert(*v, id1) {
egraph.union(id1, id2);
}
Expand All @@ -190,7 +190,7 @@ impl<L: Language, A: Analysis<L>> Applier<L, A> for MultiPattern<L> {
let mut bound_vars = HashSet::default();
let mut vars = vec![];
for (bv, pat) in &self.asts {
for n in pat.as_ref() {
for n in pat {
if let ENodeOrVar::Var(v) = n {
// using vars that are already bound doesn't count
if !bound_vars.contains(v) {
Expand Down
Loading

0 comments on commit fa40e29

Please sign in to comment.