Skip to content

Commit

Permalink
Fix improved ematching on ground terms (egraphs-good#92)
Browse files Browse the repository at this point in the history
* improve matching performance for constant constraint

* improved matching performance for ground terms

* fix typo

* style

* cache once and better planning with TodoList

* add lookup_expr and refactor machine

* delete debug println

* fix a typo

* fix issue 91
  • Loading branch information
yihozhang authored May 14, 2021
1 parent 7d9c4a1 commit a67cdcf
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 8 deletions.
12 changes: 12 additions & 0 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,18 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
id.map(|&id| self.find(id))
}

/// Lookup the eclass of the given [`RecExpr`].
pub fn lookup_expr(&self, expr: &RecExpr<L>) -> Option<Id> {
let nodes = expr.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
for node in nodes {
let node = node.clone().map_children(|i| new_ids[usize::from(i)]);
let id = self.lookup(node)?;
new_ids.push(id)
}
Some(*new_ids.last().unwrap())
}

/// Adds an enode to the [`EGraph`].
///
/// When adding an enode, to the egraph, [`add`] it performs
Expand Down
113 changes: 105 additions & 8 deletions src/machine.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::util::HashSet;
use crate::{Analysis, EClass, EGraph, ENodeOrVar, Id, Language, PatternAst, Subst, Var};
use crate::{Analysis, EClass, EGraph, ENodeOrVar, Id, Language, PatternAst, RecExpr, Subst, Var};
use std::cmp::Ordering;

struct Machine {
Expand All @@ -18,6 +18,7 @@ struct Reg(u32);
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Program<L> {
instructions: Vec<Instruction<L>>,
ground_terms: Vec<RecExpr<L>>,
subst: Subst,
}

Expand Down Expand Up @@ -114,6 +115,8 @@ type TodoList<L> = std::collections::BinaryHeap<Todo<L>>;
#[derive(PartialEq, Eq)]
struct Todo<L> {
reg: Reg,
is_ground: bool,
loc: usize,
pat: ENodeOrVar<L>,
}

Expand All @@ -127,6 +130,11 @@ impl<L: Language> Ord for Todo<L> {
// TodoList is a max-heap, so we greater is higher priority
fn cmp(&self, other: &Self) -> Ordering {
use ENodeOrVar::*;
match (&self.is_ground, &other.is_ground) {
(true, false) => return Ordering::Greater,
(false, true) => return Ordering::Less,
_ => (),
};
match (&self.pat, &other.pat) {
// fewer children means higher priority
(ENode(e1), ENode(e2)) => e2.len().cmp(&e1.len()),
Expand All @@ -147,23 +155,96 @@ struct Compiler<'a, L> {

impl<'a, L: Language> Compiler<'a, L> {
fn compile(pattern: &'a [ENodeOrVar<L>]) -> Program<L> {
let last = pattern.last().unwrap();
let mut compiler = Self {
pattern,
v2r: Default::default(),
todo: Default::default(),
out: Reg(1),
out: Reg(0),
};
compiler.todo.push(Todo {
reg: Reg(0),
pat: last.clone(),
});
compiler.go()
}

fn get_ground_locs(&mut self, is_ground: &Vec<bool>) -> Vec<Option<Reg>> {
let mut ground_locs: Vec<Option<Reg>> = vec![None; self.pattern.len()];
for i in 0..self.pattern.len() {
if let ENodeOrVar::ENode(node) = &self.pattern[i] {
if !is_ground[i] {
node.for_each(|c| {
let c = usize::from(c);
// If a ground pattern is a child of a non-ground pattern,
// we load the ground pattern
if is_ground[c] && ground_locs[c].is_none() {
if let ENodeOrVar::ENode(_) = &self.pattern[c] {
ground_locs[c] = Some(self.out);
self.out.0 += 1;
} else {
unreachable!("ground locs");
}
}
})
}
}
}
if *is_ground.last().unwrap() {
if let Some(_) = self.pattern.last() {
*ground_locs.last_mut().unwrap() = Some(self.out);
self.out.0 += 1;
} else {
unreachable!("ground locs");
}
}
ground_locs
}

fn build_ground_terms(&self, loc: usize, expr: &mut RecExpr<L>) {
if let ENodeOrVar::ENode(mut node) = self.pattern[loc].clone() {
node.update_children(|c| {
self.build_ground_terms(usize::from(c), expr);
(expr.as_ref().len() - 1).into()
});
expr.add(node);
} else {
panic!("could only build ground terms");
}
}

fn go(&mut self) -> Program<L> {
let mut is_ground: Vec<bool> = vec![false; self.pattern.len()];
for i in 0..self.pattern.len() {
if let ENodeOrVar::ENode(node) = &self.pattern[i] {
is_ground[i] = node.all(|c| is_ground[usize::from(c)]);
}
}

let ground_locs = self.get_ground_locs(&is_ground);
let mut ground_terms: Vec<(u32, RecExpr<L>)> = ground_locs
.iter()
.enumerate()
.filter_map(|(i, r)| r.map(|r| (i, r.0)))
.map(|(i, r)| {
let mut expr = Default::default();
self.build_ground_terms(i, &mut expr);
(r, expr)
})
.collect();
ground_terms.sort_by_key(|(r, _expr)| *r);
let ground_terms: Vec<RecExpr<L>> =
ground_terms.into_iter().map(|(_r, expr)| expr).collect();

self.todo.push(Todo {
reg: Reg(self.out.0),
is_ground: is_ground[self.pattern.len() - 1],
loc: self.pattern.len() - 1,
pat: self.pattern.last().unwrap().clone(),
});
self.out.0 += 1;

let mut instructions = vec![];
while let Some(Todo { reg: i, pat }) = self.todo.pop() {

while let Some(Todo {
reg: i, pat, loc, ..
}) = self.todo.pop()
{
match pat {
ENodeOrVar::Var(v) => {
if let Some(&j) = self.v2r.get(&v) {
Expand All @@ -173,6 +254,11 @@ impl<'a, L: Language> Compiler<'a, L> {
}
}
ENodeOrVar::ENode(node) => {
if let Some(j) = ground_locs[loc] {
instructions.push(Instruction::Compare { i, j });
continue;
}

let out = self.out;
self.out.0 += node.len() as u32;

Expand All @@ -181,6 +267,8 @@ impl<'a, L: Language> Compiler<'a, L> {
let r = Reg(out.0 + id as u32);
self.todo.push(Todo {
reg: r,
is_ground: is_ground[usize::from(child)],
loc: usize::from(child),
pat: self.pattern[usize::from(child)].clone(),
});
id += 1;
Expand All @@ -197,9 +285,11 @@ impl<'a, L: Language> Compiler<'a, L> {
for (v, r) in &self.v2r {
subst.insert(*v, Id::from(r.0 as usize));
}

Program {
instructions,
subst,
ground_terms,
}
}
}
Expand All @@ -218,6 +308,13 @@ impl<L: Language> Program<L> {
let mut machine = Machine::default();

assert_eq!(machine.reg.len(), 0);
for expr in &self.ground_terms {
if let Some(id) = egraph.lookup_expr(&mut expr.clone()) {
machine.reg.push(id)
} else {
return vec![];
}
}
machine.reg.push(eclass);

let mut substs = Vec::new();
Expand Down

0 comments on commit a67cdcf

Please sign in to comment.