diff --git a/src/egg/machine.rs.html b/src/egg/machine.rs.html index edd4159d..1ec567f8 100644 --- a/src/egg/machine.rs.html +++ b/src/egg/machine.rs.html @@ -394,16 +394,50 @@ 389 390 391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407
use crate::*;
-use std::result;
 
-type Result = result::Result<(), ()>;
-
-#[derive(Default)]
-struct Machine {
+struct Machine<'a, L: Language, N: Analysis<L>> {
     reg: Vec<Id>,
     // a buffer to re-use for lookups
     lookup: Vec<Id>,
+    stack: Vec<MachineContext>,
+    instructions: &'a [Instruction<L>],
+    egraph: &'a EGraph<L, N>,
+    subst: Subst,
+}
+
+impl<'a, L: Language, N: Analysis<L>> Machine<'a, L, N> {
+    #[inline(always)]
+    fn reg(&self, reg: Reg) -> Id {
+        self.reg[reg.0 as usize]
+    }
+
+    fn new(instructions: &'a [Instruction<L>], egraph: &'a EGraph<L, N>, subst: Subst) -> Self {
+        Machine {
+            reg: Default::default(),
+            lookup: Default::default(),
+            stack: Default::default(),
+            instructions,
+            egraph,
+            subst,
+        }
+    }
 }
 
 #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
@@ -417,7 +451,7 @@
 
 #[derive(Debug, Clone, PartialEq, Eq)]
 enum Instruction<L> {
-    Bind { node: L, i: Reg, out: Reg },
+    Bind { node: L, eclass: Reg, out: Reg },
     Compare { i: Reg, j: Reg },
     Lookup { term: Vec<ENodeOrReg<L>>, i: Reg },
     Scan { out: Reg },
@@ -433,8 +467,8 @@
 fn for_each_matching_node<L, D>(
     eclass: &EClass<L, D>,
     node: &L,
-    mut f: impl FnMut(&L) -> Result,
-) -> Result
+    mut f: impl FnMut(&L) -> (),
+)
 where
     L: Language,
 {
@@ -444,7 +478,7 @@
             .nodes
             .iter()
             .filter(|n| node.matches(n))
-            .try_for_each(f)
+            .for_each(f)
     } else {
         debug_assert!(node.all(|id| id == Id::from(0)));
         debug_assert!(eclass.nodes.windows(2).all(|w| w[0] < w[1]));
@@ -457,7 +491,7 @@
                 break;
             }
         }
-        let mut matching = eclass.nodes[start..]
+        let matching = eclass.nodes[start..]
             .iter()
             .take_while(|&n| std::mem::discriminant(n) == discrim)
             .filter(|n| node.matches(n));
@@ -475,78 +509,97 @@
                 .collect::<HashSet<_>>(),
             eclass.nodes
         );
-        matching.try_for_each(&mut f)
+        matching.for_each(&mut f)
     }
 }
 
-impl Machine {
-    #[inline(always)]
-    fn reg(&self, reg: Reg) -> Id {
-        self.reg[reg.0 as usize]
+struct MachineContext {
+    instruction_index: usize,
+    truncate: usize,
+    to_push: Vec<Id>,
+}
+
+impl MachineContext {
+    fn new(instruction_index: usize, truncate: usize, push: Vec<Id>) -> Self {
+        Self {
+            instruction_index,
+            truncate,
+            to_push: push
+        }
     }
+}
 
-    fn run<L, N>(
-        &mut self,
-        egraph: &EGraph<L, N>,
-        instructions: &[Instruction<L>],
-        subst: &Subst,
-        yield_fn: &mut impl FnMut(&Self, &Subst) -> Result,
-    ) -> Result
-    where
-        L: Language,
-        N: Analysis<L>,
-    {
-        let mut instructions = instructions.iter();
-        while let Some(instruction) = instructions.next() {
-            match instruction {
-                Instruction::Bind { i, out, node } => {
-                    let remaining_instructions = instructions.as_slice();
-                    return for_each_matching_node(&egraph[self.reg(*i)], node, |matched| {
-                        self.reg.truncate(out.0 as usize);
-                        matched.for_each(|id| self.reg.push(id));
-                        self.run(egraph, remaining_instructions, subst, yield_fn)
-                    });
-                }
-                Instruction::Scan { out } => {
-                    let remaining_instructions = instructions.as_slice();
-                    for class in egraph.classes() {
-                        self.reg.truncate(out.0 as usize);
-                        self.reg.push(class.id);
-                        self.run(egraph, remaining_instructions, subst, yield_fn)?
+impl<'a, L: Language, N: Analysis<L>> Iterator for Machine<'a, L, N> {
+    type Item = Subst;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        while !self.stack.is_empty() {
+            let current_state = self.stack.pop().unwrap();
+            self.reg.truncate(current_state.truncate);
+            for id in current_state.to_push {
+                self.reg.push(id);
+            }
+            let mut index = current_state.instruction_index;
+            'instr: while index < self.instructions.len() {
+                let instruction = &self.instructions[index];
+                match instruction {
+                    Instruction::Bind { eclass, out, node } => {
+                        for_each_matching_node(&self.egraph[self.reg(*eclass)], node, |matched| {
+                            let truncate = out.0 as usize;
+                            let to_push = matched.children().iter().copied().collect();
+                            self.stack.push(MachineContext::new(index + 1, truncate, to_push));
+                        });
+                        break;
                     }
-                    return Ok(());
-                }
-                Instruction::Compare { i, j } => {
-                    if egraph.find(self.reg(*i)) != egraph.find(self.reg(*j)) {
-                        return Ok(());
+                    Instruction::Scan { out } => {
+                        for class in self.egraph.classes() {
+                            let truncate = out.0 as usize;
+                            let to_push = vec![class.id];
+                            self.stack.push(MachineContext::new(index + 1, truncate, to_push));
+                        }
+                        break;
                     }
-                }
-                Instruction::Lookup { term, i } => {
-                    self.lookup.clear();
-                    for node in term {
-                        match node {
-                            ENodeOrReg::ENode(node) => {
-                                let look = |i| self.lookup[usize::from(i)];
-                                match egraph.lookup(node.clone().map_children(look)) {
-                                    Some(id) => self.lookup.push(id),
-                                    None => return Ok(()),
+                    Instruction::Compare { i, j } => {
+                        if self.egraph.find(self.reg(*i)) != self.egraph.find(self.reg(*j)) {
+                            break;
+                        }
+                    }
+                    Instruction::Lookup { term, i } => {
+                        self.lookup.clear();
+                        for node in term {
+                            match node {
+                                ENodeOrReg::ENode(node) => {
+                                    let look = |i| self.lookup[usize::from(i)];
+                                    match self.egraph.lookup(node.clone().map_children(look)) {
+                                        Some(id) => self.lookup.push(id),
+                                        None => break 'instr,
+                                    }
+                                }
+                                ENodeOrReg::Reg(r) => {
+                                    self.lookup.push(self.egraph.find(self.reg(*r)));
                                 }
-                            }
-                            ENodeOrReg::Reg(r) => {
-                                self.lookup.push(egraph.find(self.reg(*r)));
                             }
                         }
-                    }
 
-                    let id = egraph.find(self.reg(*i));
-                    if self.lookup.last().copied() != Some(id) {
-                        return Ok(());
+                        let id = self.egraph.find(self.reg(*i));
+                        if self.lookup.last().copied() != Some(id) {
+                            break 'instr;
+                        }
                     }
-                }
+                };
+                index += 1;
+            }
+            if index == self.instructions.len() {
+                let subst_vec = self.subst
+                    .vec
+                    .iter()
+                    // HACK we are reusing Ids here, this is bad
+                    .map(|(v, reg_id)| (*v, self.reg(Reg(usize::from(*reg_id) as u32))))
+                    .collect();
+                return Some(Subst { vec: subst_vec });
             }
         }
-
-        yield_fn(self, subst)
+        return None;
     }
 }
 
@@ -696,7 +749,7 @@
                 // zero out the children so Bind can use it to sort
                 let op = node.clone().map_children(|_| Id::from(0));
                 self.instructions.push(Instruction::Bind {
-                    i: reg,
+                    eclass: reg,
                     node: op,
                     out,
                 });
@@ -742,7 +795,7 @@
         &self,
         egraph: &EGraph<L, A>,
         eclass: Id,
-        mut limit: usize,
+        limit: usize,
     ) -> Vec<Subst>
     where
         A: Analysis<L>,
@@ -753,33 +806,12 @@
             return vec![];
         }
 
-        let mut machine = Machine::default();
+        let mut machine = Machine::new(&self.instructions, egraph, self.subst.clone());
         assert_eq!(machine.reg.len(), 0);
         machine.reg.push(eclass);
+        machine.stack.push(MachineContext::new(0,  1, vec![]));
 
-        let mut matches = Vec::new();
-        machine
-            .run(
-                egraph,
-                &self.instructions,
-                &self.subst,
-                &mut |machine, subst| {
-                    let subst_vec = subst
-                        .vec
-                        .iter()
-                        // HACK we are reusing Ids here, this is bad
-                        .map(|(v, reg_id)| (*v, machine.reg(Reg(usize::from(*reg_id) as u32))))
-                        .collect();
-                    matches.push(Subst { vec: subst_vec });
-                    limit -= 1;
-                    if limit != 0 {
-                        Ok(())
-                    } else {
-                        Err(())
-                    }
-                },
-            )
-            .unwrap_or_default();
+        let matches = machine.into_iter().take(limit).collect();
 
         log::trace!("Ran program, found {:?}", matches);
         matches