Skip to content

Support tail call in the WASM interpreter #514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions crates/interpreter/src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,13 @@ impl<'m> Thread<'m> {
return Ok(self.pop_label(inst, ls.get(i).cloned().unwrap_or(ln)));
}
Return => return Ok(self.exit_frame()),
Call(x) => return self.invoke(store, store.func_ptr(inst_id, x)),
Call(x) => {
return if self.parser.is_tail_call() {
self.invoke_tail_call(store, store.func_ptr(inst_id, x))
} else {
self.invoke(store, store.func_ptr(inst_id, x))
}
}
CallIndirect(x, y) => {
let i = self.pop_value().unwrap_i32();
let x = match store.table(inst_id, x).elems.get(i as usize) {
Expand All @@ -805,7 +811,11 @@ impl<'m> Thread<'m> {
if store.func_type(x) != store.insts[inst_id].module.types()[y as usize] {
return Err(trap());
}
return self.invoke(store, x);
return if self.parser.is_tail_call() {
self.invoke_tail_call(store, x)
} else {
self.invoke(store, x)
};
}
Drop => drop(self.pop_value()),
Select(_) => {
Expand Down Expand Up @@ -1369,6 +1379,32 @@ impl<'m> Thread<'m> {
self.frames.push(Frame::new(inst_id, t.results.len(), ret, locals));
Ok(ThreadResult::Continue(self))
}

fn invoke_tail_call(
mut self, store: &mut Store<'m>, ptr: Ptr,
) -> Result<ThreadResult<'m>, Error> {
let t = store.func_type(ptr);
match ptr.instance() {
Side::Host => {
let index = ptr.index() as usize;
let t = store.funcs[index].1;
let arity = t.results.len();
let args = self.pop_values(t.params.len());
store.threads.push(Continuation { thread: self, arity, index, args });
Ok(ThreadResult::Host)
}
Side::Wasm(inst_id) => {
let mut parser = store.insts[inst_id].module.func(ptr.index());

// Reuse the existing frame (no push)
self.frame().locals = self.pop_values(t.params.len());
append_locals(&mut parser, &mut self.frame().locals);

self.parser = parser;
Ok(ThreadResult::Continue(self))
}
}
}
}

fn table_init(d: usize, s: usize, n: usize, table: &mut Table, elems: &[Val]) -> Result<(), Error> {
Expand Down
76 changes: 76 additions & 0 deletions crates/interpreter/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,82 @@ impl<'m> Parser<'m, Use> {
pub unsafe fn new(data: &'m [u8]) -> Self {
Self::internal_new(data)
}

pub fn is_tail_call(&self) -> bool {
let mut remaining = self.data;
let mut block_depth = 0;
let mut call_depth = 0;

while !remaining.is_empty() {
let mut temp_parser: Parser<'_, Check> = Parser { data: remaining, mode: PhantomData };

if remaining.len() == 0 {
return false;
} else if remaining.len() == 1 {
return remaining[0] == 0x0B && block_depth == 0 && call_depth == 1;
}

let opcode = temp_parser.parse_byte().unwrap();
remaining = temp_parser.data;

if opcode == 0x02 || opcode == 0x03 || opcode == 0x04 {
block_depth += 1;
} else if opcode == 0x0B {
block_depth -= 1;
if call_depth > 0 && block_depth == 0 {
call_depth -= 1;
}
} else if opcode == 0x10 || opcode == 0x11 {
call_depth += 1;
if block_depth == 0 && call_depth == 1 {
return !remaining.is_empty() && remaining[0] == 0x0B;
}
}

match opcode {
0x0E => {
// br_table
let _num_labels = temp_parser.parse_u32().unwrap();
for _ in 0 .. _num_labels + 1 {
temp_parser.parse_labelidx().unwrap();
}
remaining = temp_parser.data;
}
0x28 ..= 0x3E => {
// Memory instructions
temp_parser.parse_memarg().unwrap();
remaining = temp_parser.data;
}
0xFC => {
let fc_opcode = temp_parser.parse_u32().unwrap();
match fc_opcode {
0 ..= 3 => {
// Using range pattern
temp_parser.parse_leb128(true, 33).unwrap();
remaining = temp_parser.data;
}
4 => {
temp_parser.parse_dataidx().unwrap();
temp_parser.parse_byte().unwrap();
remaining = temp_parser.data;
}
5 | 6 => {
temp_parser.parse_elemidx().unwrap();
remaining = temp_parser.data;
}
7 => {
temp_parser.parse_tableidx().unwrap();
temp_parser.parse_elemidx().unwrap();
remaining = temp_parser.data;
}
_ => (), // Unsupported or no arguments
}
}
_ => (), // Other instructions with no immediate arguments
}
}
false // Not a tail call if reached end without finding one
}
}

impl<'m, M: Mode> Parser<'m, M> {
Expand Down
Loading