diff --git a/src/egg/machine.rs.html b/src/egg/machine.rs.html index edd4159d..1ec567f8 100644 --- a/src/egg/machine.rs.html +++ b/src/egg/machine.rs.html @@ -394,16 +394,50 @@ 389 390 391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407
use crate::*;
-use std::result;
-type Result = result::Result<(), ()>;
-
-#[derive(Default)]
-struct Machine {
+struct Machine<'a, L: Language, N: Analysis<L>> {
reg: Vec<Id>,
// a buffer to re-use for lookups
lookup: Vec<Id>,
+ stack: Vec<MachineContext>,
+ instructions: &'a [Instruction<L>],
+ egraph: &'a EGraph<L, N>,
+ subst: Subst,
+}
+
+impl<'a, L: Language, N: Analysis<L>> Machine<'a, L, N> {
+ #[inline(always)]
+ fn reg(&self, reg: Reg) -> Id {
+ self.reg[reg.0 as usize]
+ }
+
+ fn new(instructions: &'a [Instruction<L>], egraph: &'a EGraph<L, N>, subst: Subst) -> Self {
+ Machine {
+ reg: Default::default(),
+ lookup: Default::default(),
+ stack: Default::default(),
+ instructions,
+ egraph,
+ subst,
+ }
+ }
}
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
@@ -417,7 +451,7 @@
#[derive(Debug, Clone, PartialEq, Eq)]
enum Instruction<L> {
- Bind { node: L, i: Reg, out: Reg },
+ Bind { node: L, eclass: Reg, out: Reg },
Compare { i: Reg, j: Reg },
Lookup { term: Vec<ENodeOrReg<L>>, i: Reg },
Scan { out: Reg },
@@ -433,8 +467,8 @@
fn for_each_matching_node<L, D>(
eclass: &EClass<L, D>,
node: &L,
- mut f: impl FnMut(&L) -> Result,
-) -> Result
+ mut f: impl FnMut(&L) -> (),
+)
where
L: Language,
{
@@ -444,7 +478,7 @@
.nodes
.iter()
.filter(|n| node.matches(n))
- .try_for_each(f)
+ .for_each(f)
} else {
debug_assert!(node.all(|id| id == Id::from(0)));
debug_assert!(eclass.nodes.windows(2).all(|w| w[0] < w[1]));
@@ -457,7 +491,7 @@
break;
}
}
- let mut matching = eclass.nodes[start..]
+ let matching = eclass.nodes[start..]
.iter()
.take_while(|&n| std::mem::discriminant(n) == discrim)
.filter(|n| node.matches(n));
@@ -475,78 +509,97 @@
.collect::<HashSet<_>>(),
eclass.nodes
);
- matching.try_for_each(&mut f)
+ matching.for_each(&mut f)
}
}
-impl Machine {
- #[inline(always)]
- fn reg(&self, reg: Reg) -> Id {
- self.reg[reg.0 as usize]
+struct MachineContext {
+ instruction_index: usize,
+ truncate: usize,
+ to_push: Vec<Id>,
+}
+
+impl MachineContext {
+ fn new(instruction_index: usize, truncate: usize, push: Vec<Id>) -> Self {
+ Self {
+ instruction_index,
+ truncate,
+ to_push: push
+ }
}
+}
- fn run<L, N>(
- &mut self,
- egraph: &EGraph<L, N>,
- instructions: &[Instruction<L>],
- subst: &Subst,
- yield_fn: &mut impl FnMut(&Self, &Subst) -> Result,
- ) -> Result
- where
- L: Language,
- N: Analysis<L>,
- {
- let mut instructions = instructions.iter();
- while let Some(instruction) = instructions.next() {
- match instruction {
- Instruction::Bind { i, out, node } => {
- let remaining_instructions = instructions.as_slice();
- return for_each_matching_node(&egraph[self.reg(*i)], node, |matched| {
- self.reg.truncate(out.0 as usize);
- matched.for_each(|id| self.reg.push(id));
- self.run(egraph, remaining_instructions, subst, yield_fn)
- });
- }
- Instruction::Scan { out } => {
- let remaining_instructions = instructions.as_slice();
- for class in egraph.classes() {
- self.reg.truncate(out.0 as usize);
- self.reg.push(class.id);
- self.run(egraph, remaining_instructions, subst, yield_fn)?
+impl<'a, L: Language, N: Analysis<L>> Iterator for Machine<'a, L, N> {
+ type Item = Subst;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ while !self.stack.is_empty() {
+ let current_state = self.stack.pop().unwrap();
+ self.reg.truncate(current_state.truncate);
+ for id in current_state.to_push {
+ self.reg.push(id);
+ }
+ let mut index = current_state.instruction_index;
+ 'instr: while index < self.instructions.len() {
+ let instruction = &self.instructions[index];
+ match instruction {
+ Instruction::Bind { eclass, out, node } => {
+ for_each_matching_node(&self.egraph[self.reg(*eclass)], node, |matched| {
+ let truncate = out.0 as usize;
+ let to_push = matched.children().iter().copied().collect();
+ self.stack.push(MachineContext::new(index + 1, truncate, to_push));
+ });
+ break;
}
- return Ok(());
- }
- Instruction::Compare { i, j } => {
- if egraph.find(self.reg(*i)) != egraph.find(self.reg(*j)) {
- return Ok(());
+ Instruction::Scan { out } => {
+ for class in self.egraph.classes() {
+ let truncate = out.0 as usize;
+ let to_push = vec![class.id];
+ self.stack.push(MachineContext::new(index + 1, truncate, to_push));
+ }
+ break;
}
- }
- Instruction::Lookup { term, i } => {
- self.lookup.clear();
- for node in term {
- match node {
- ENodeOrReg::ENode(node) => {
- let look = |i| self.lookup[usize::from(i)];
- match egraph.lookup(node.clone().map_children(look)) {
- Some(id) => self.lookup.push(id),
- None => return Ok(()),
+ Instruction::Compare { i, j } => {
+ if self.egraph.find(self.reg(*i)) != self.egraph.find(self.reg(*j)) {
+ break;
+ }
+ }
+ Instruction::Lookup { term, i } => {
+ self.lookup.clear();
+ for node in term {
+ match node {
+ ENodeOrReg::ENode(node) => {
+ let look = |i| self.lookup[usize::from(i)];
+ match self.egraph.lookup(node.clone().map_children(look)) {
+ Some(id) => self.lookup.push(id),
+ None => break 'instr,
+ }
+ }
+ ENodeOrReg::Reg(r) => {
+ self.lookup.push(self.egraph.find(self.reg(*r)));
}
- }
- ENodeOrReg::Reg(r) => {
- self.lookup.push(egraph.find(self.reg(*r)));
}
}
- }
- let id = egraph.find(self.reg(*i));
- if self.lookup.last().copied() != Some(id) {
- return Ok(());
+ let id = self.egraph.find(self.reg(*i));
+ if self.lookup.last().copied() != Some(id) {
+ break 'instr;
+ }
}
- }
+ };
+ index += 1;
+ }
+ if index == self.instructions.len() {
+ let subst_vec = self.subst
+ .vec
+ .iter()
+ // HACK we are reusing Ids here, this is bad
+ .map(|(v, reg_id)| (*v, self.reg(Reg(usize::from(*reg_id) as u32))))
+ .collect();
+ return Some(Subst { vec: subst_vec });
}
}
-
- yield_fn(self, subst)
+ return None;
}
}
@@ -696,7 +749,7 @@
// zero out the children so Bind can use it to sort
let op = node.clone().map_children(|_| Id::from(0));
self.instructions.push(Instruction::Bind {
- i: reg,
+ eclass: reg,
node: op,
out,
});
@@ -742,7 +795,7 @@
&self,
egraph: &EGraph<L, A>,
eclass: Id,
- mut limit: usize,
+ limit: usize,
) -> Vec<Subst>
where
A: Analysis<L>,
@@ -753,33 +806,12 @@
return vec![];
}
- let mut machine = Machine::default();
+ let mut machine = Machine::new(&self.instructions, egraph, self.subst.clone());
assert_eq!(machine.reg.len(), 0);
machine.reg.push(eclass);
+ machine.stack.push(MachineContext::new(0, 1, vec![]));
- let mut matches = Vec::new();
- machine
- .run(
- egraph,
- &self.instructions,
- &self.subst,
- &mut |machine, subst| {
- let subst_vec = subst
- .vec
- .iter()
- // HACK we are reusing Ids here, this is bad
- .map(|(v, reg_id)| (*v, machine.reg(Reg(usize::from(*reg_id) as u32))))
- .collect();
- matches.push(Subst { vec: subst_vec });
- limit -= 1;
- if limit != 0 {
- Ok(())
- } else {
- Err(())
- }
- },
- )
- .unwrap_or_default();
+ let matches = machine.into_iter().take(limit).collect();
log::trace!("Ran program, found {:?}", matches);
matches