diff --git a/crates/interpreter/src/exec.rs b/crates/interpreter/src/exec.rs index a205e16a9..02207d5dc 100644 --- a/crates/interpreter/src/exec.rs +++ b/crates/interpreter/src/exec.rs @@ -795,7 +795,13 @@ impl<'m> Thread<'m> { return Ok(self.pop_label(inst, ls.get(i).cloned().unwrap_or(ln))); } Return => return Ok(self.exit_frame()), - Call(x) => return self.invoke(store, store.func_ptr(inst_id, x)), + Call(x) => { + return if self.parser.is_tail_call() { + self.invoke_tail_call(store, store.func_ptr(inst_id, x)) + } else { + self.invoke(store, store.func_ptr(inst_id, x)) + } + } CallIndirect(x, y) => { let i = self.pop_value().unwrap_i32(); let x = match store.table(inst_id, x).elems.get(i as usize) { @@ -805,7 +811,11 @@ impl<'m> Thread<'m> { if store.func_type(x) != store.insts[inst_id].module.types()[y as usize] { return Err(trap()); } - return self.invoke(store, x); + return if self.parser.is_tail_call() { + self.invoke_tail_call(store, x) + } else { + self.invoke(store, x) + }; } Drop => drop(self.pop_value()), Select(_) => { @@ -1369,6 +1379,32 @@ impl<'m> Thread<'m> { self.frames.push(Frame::new(inst_id, t.results.len(), ret, locals)); Ok(ThreadResult::Continue(self)) } + + fn invoke_tail_call( + mut self, store: &mut Store<'m>, ptr: Ptr, + ) -> Result, Error> { + let t = store.func_type(ptr); + match ptr.instance() { + Side::Host => { + let index = ptr.index() as usize; + let t = store.funcs[index].1; + let arity = t.results.len(); + let args = self.pop_values(t.params.len()); + store.threads.push(Continuation { thread: self, arity, index, args }); + Ok(ThreadResult::Host) + } + Side::Wasm(inst_id) => { + let mut parser = store.insts[inst_id].module.func(ptr.index()); + + // Reuse the existing frame (no push) + self.frame().locals = self.pop_values(t.params.len()); + append_locals(&mut parser, &mut self.frame().locals); + + self.parser = parser; + Ok(ThreadResult::Continue(self)) + } + } + } } fn table_init(d: usize, s: usize, n: usize, table: &mut Table, elems: &[Val]) -> Result<(), Error> { diff --git a/crates/interpreter/src/parser.rs b/crates/interpreter/src/parser.rs index 8755e1dec..3b2e26a95 100644 --- a/crates/interpreter/src/parser.rs +++ b/crates/interpreter/src/parser.rs @@ -37,6 +37,82 @@ impl<'m> Parser<'m, Use> { pub unsafe fn new(data: &'m [u8]) -> Self { Self::internal_new(data) } + + pub fn is_tail_call(&self) -> bool { + let mut remaining = self.data; + let mut block_depth = 0; + let mut call_depth = 0; + + while !remaining.is_empty() { + let mut temp_parser: Parser<'_, Check> = Parser { data: remaining, mode: PhantomData }; + + if remaining.len() == 0 { + return false; + } else if remaining.len() == 1 { + return remaining[0] == 0x0B && block_depth == 0 && call_depth == 1; + } + + let opcode = temp_parser.parse_byte().unwrap(); + remaining = temp_parser.data; + + if opcode == 0x02 || opcode == 0x03 || opcode == 0x04 { + block_depth += 1; + } else if opcode == 0x0B { + block_depth -= 1; + if call_depth > 0 && block_depth == 0 { + call_depth -= 1; + } + } else if opcode == 0x10 || opcode == 0x11 { + call_depth += 1; + if block_depth == 0 && call_depth == 1 { + return !remaining.is_empty() && remaining[0] == 0x0B; + } + } + + match opcode { + 0x0E => { + // br_table + let _num_labels = temp_parser.parse_u32().unwrap(); + for _ in 0 .. _num_labels + 1 { + temp_parser.parse_labelidx().unwrap(); + } + remaining = temp_parser.data; + } + 0x28 ..= 0x3E => { + // Memory instructions + temp_parser.parse_memarg().unwrap(); + remaining = temp_parser.data; + } + 0xFC => { + let fc_opcode = temp_parser.parse_u32().unwrap(); + match fc_opcode { + 0 ..= 3 => { + // Using range pattern + temp_parser.parse_leb128(true, 33).unwrap(); + remaining = temp_parser.data; + } + 4 => { + temp_parser.parse_dataidx().unwrap(); + temp_parser.parse_byte().unwrap(); + remaining = temp_parser.data; + } + 5 | 6 => { + temp_parser.parse_elemidx().unwrap(); + remaining = temp_parser.data; + } + 7 => { + temp_parser.parse_tableidx().unwrap(); + temp_parser.parse_elemidx().unwrap(); + remaining = temp_parser.data; + } + _ => (), // Unsupported or no arguments + } + } + _ => (), // Other instructions with no immediate arguments + } + } + false // Not a tail call if reached end without finding one + } } impl<'m, M: Mode> Parser<'m, M> {