Skip to content

Commit

Permalink
Search for partial ground terms as well
Browse files Browse the repository at this point in the history
  • Loading branch information
mwillsey committed Nov 16, 2021
1 parent 9ac1789 commit 097a9e5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 28 deletions.
1 change: 1 addition & 0 deletions src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ impl<L: Language> RecExpr<L> {
let mut todo = vec![new_root];
while let Some(id) = todo.last().copied() {
if ids.contains_key(&id) {
todo.pop();
continue;
}

Expand Down
81 changes: 53 additions & 28 deletions src/machine.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
use crate::*;
use std::cmp::Ordering;

#[derive(Default)]
struct Machine {
reg: Vec<Id>,
}

impl Default for Machine {
fn default() -> Self {
Self { reg: vec![] }
}
// a buffer to re-use for lookups
lookup: Vec<Id>,
}

#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
Expand All @@ -24,7 +21,13 @@ pub struct Program<L> {
enum Instruction<L> {
Bind { node: L, i: Reg, out: Reg },
Compare { i: Reg, j: Reg },
Lookup { term: PatternAst<L>, i: Reg },
Lookup { term: Vec<ENodeOrReg<L>>, i: Reg },
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum ENodeOrReg<L> {
ENode(L),
Reg(Reg),
}

#[inline(always)]
Expand Down Expand Up @@ -102,25 +105,24 @@ impl Machine {
}
}
Instruction::Lookup { term, i } => {
let mut new_ids = Vec::with_capacity(term.as_ref().len());
for node in term.as_ref() {
self.lookup.clear();
for node in term {
match node {
ENodeOrVar::ENode(node) => {
let node = node.clone().map_children(|i| new_ids[usize::from(i)]);
match egraph.lookup(node) {
Some(id) => new_ids.push(id),
ENodeOrReg::ENode(node) => {
let look = |i| self.lookup[usize::from(i)];
match egraph.lookup(node.clone().map_children(look)) {
Some(id) => self.lookup.push(id),
None => return,
}
}
ENodeOrVar::Var(_) => {
panic!("Lookup instruction only supports ground terms right now")
// in the future, this could ids for registers
ENodeOrReg::Reg(r) => {
self.lookup.push(egraph.find(self.reg(*r)));
}
}
}

let id = egraph.find(self.reg(*i));
if new_ids.last().copied() != Some(id) {
if self.lookup.last().copied() != Some(id) {
return;
}
}
Expand All @@ -137,7 +139,7 @@ type TodoList<L> = std::collections::BinaryHeap<Todo<L>>;
#[derive(PartialEq, Eq)]
struct Todo<L> {
reg: Reg,
is_ground: bool,
n_free: usize,
pat: ENodeOrVar<L>,
/// location within the pattern
id: Id,
Expand All @@ -153,8 +155,8 @@ impl<L: Language> Ord for Todo<L> {
// TodoList is a max-heap, so we greater is higher priority
fn cmp(&self, other: &Self) -> Ordering {
use ENodeOrVar::*;
let cmp_ground = self.is_ground.cmp(&other.is_ground);
cmp_ground.then_with(|| match (&self.pat, &other.pat) {
let cmp_free = self.n_free.cmp(&other.n_free);
cmp_free.then_with(|| match (&self.pat, &other.pat) {
// fewer children means higher priority
(ENode(e1), ENode(e2)) => e2.len().cmp(&e1.len()),
// Var is higher prio than enode
Expand Down Expand Up @@ -186,17 +188,27 @@ impl<'a, L: Language> Compiler<'a, L> {
fn go(&mut self) -> Program<L> {
let nodes = self.pattern.as_ref();
let len = nodes.len();
let mut is_ground: Vec<bool> = vec![false; len];
for (i, node) in nodes.iter().enumerate() {
if let ENodeOrVar::ENode(node) = node {
is_ground[i] = node.all(|c| is_ground[usize::from(c)]);

let mut free_vars: Vec<HashSet<Var>> = Vec::with_capacity(len);
for node in nodes {
let mut free = HashSet::default();
match node {
ENodeOrVar::ENode(n) => {
for &child in n.children() {
free.extend(&free_vars[usize::from(child)])
}
}
ENodeOrVar::Var(v) => {
free.insert(*v);
}
}
free_vars.push(free)
}

let last_i = len - 1;
self.todo.push(Todo {
reg: Reg(self.out.0),
is_ground: is_ground[last_i],
n_free: free_vars[last_i].len(),
id: Id::from(last_i),
pat: nodes[last_i].clone(),
});
Expand All @@ -214,10 +226,23 @@ impl<'a, L: Language> Compiler<'a, L> {
}
}
ENodeOrVar::ENode(node) => {
if todo.is_ground && !node.is_leaf() {
// check to see if this e-node corresponds to a term
// that is grounded by the variables bound at this point
let is_ground_now = free_vars[usize::from(todo.id)]
.iter()
.all(|v| self.v2r.contains_key(v));
if is_ground_now && !node.is_leaf() {
let pattern = self.pattern.extract(todo.id);
instructions.push(Instruction::Lookup {
i: todo.reg,
term: self.pattern.extract(todo.id),
term: pattern
.as_ref()
.iter()
.map(|n| match n {
ENodeOrVar::ENode(n) => ENodeOrReg::ENode(n.clone()),
ENodeOrVar::Var(v) => ENodeOrReg::Reg(self.v2r[v]),
})
.collect(),
});
continue;
}
Expand All @@ -229,7 +254,7 @@ impl<'a, L: Language> Compiler<'a, L> {
let r = Reg(out.0 + id as u32);
self.todo.push(Todo {
reg: r,
is_ground: is_ground[usize::from(child)],
n_free: free_vars[usize::from(child)].len(),
id: child,
pat: nodes[usize::from(child)].clone(),
});
Expand Down

0 comments on commit 097a9e5

Please sign in to comment.