Use EntryStoreContext to manage state when entering and exiting Wasm #10626


Merged · 5 commits · Apr 22, 2025
222 changes: 137 additions & 85 deletions crates/wasmtime/src/runtime/func.rs
@@ -1,7 +1,7 @@
use crate::prelude::*;
use crate::runtime::vm::{
ExportFunction, InterpreterRef, SendSyncPtr, StoreBox, VMArrayCallHostFuncContext, VMContext,
VMFuncRef, VMFunctionImport, VMOpaqueContext,
VMFuncRef, VMFunctionImport, VMOpaqueContext, VMStoreContext,
};
use crate::runtime::Uninhabited;
use crate::store::{AutoAssertNoGc, StoreData, StoreOpaque, Stored};
@@ -1604,103 +1604,155 @@ pub(crate) fn invoke_wasm_and_catch_traps<T>(
closure: impl FnMut(NonNull<VMContext>, Option<InterpreterRef<'_>>) -> bool,
) -> Result<()> {
unsafe {
let exit = enter_wasm(store);
let previous_runtime_state = EntryStoreContext::enter_wasm(store);

if let Err(trap) = store.0.call_hook(CallHook::CallingWasm) {
exit_wasm(store, exit);
// `previous_runtime_state` implicitly dropped here
return Err(trap);
}
let result = crate::runtime::vm::catch_traps(store, closure);
exit_wasm(store, exit);
let result = crate::runtime::vm::catch_traps(store, &previous_runtime_state, closure);
core::mem::drop(previous_runtime_state);
store.0.call_hook(CallHook::ReturningFromWasm)?;
result.map_err(|t| crate::trap::from_runtime_box(store.0, t))
}
}
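
The control flow above relies entirely on `EntryStoreContext`'s destructor: whether the call hook fails, the wasm closure traps, or a panic unwinds the stack, the saved state is restored exactly once. A minimal, self-contained sketch of the same save-on-enter/restore-on-drop pattern (the `State` and `Guard` types here are invented for illustration, not wasmtime's actual API):

```rust
use std::cell::Cell;

// Toy stand-in for the store's `VMStoreContext` fields.
struct State {
    stack_limit: Cell<usize>,
}

// Guard that saves the previous value on entry and restores it on drop,
// mirroring how `EntryStoreContext` pairs `enter_wasm` with `exit_wasm`.
struct Guard<'a> {
    state: &'a State,
    previous: usize,
}

impl<'a> Guard<'a> {
    fn enter(state: &'a State, new_limit: usize) -> Self {
        // Save the old limit and install the new one, like `enter_wasm`.
        let previous = state.stack_limit.replace(new_limit);
        Guard { state, previous }
    }
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // Runs on normal exit, on early `return Err(..)`, and on panic
        // unwinding, mirroring `EntryStoreContext::exit_wasm`.
        self.state.stack_limit.set(self.previous);
    }
}

fn main() {
    let state = State { stack_limit: Cell::new(usize::MAX) };
    {
        let _guard = Guard::enter(&state, 0x1000);
        assert_eq!(state.stack_limit.get(), 0x1000);
        // ... call into "wasm" here ...
    } // `_guard` dropped here: the previous limit is restored.
    assert_eq!(state.stack_limit.get(), usize::MAX);
}
```

The explicit `core::mem::drop(previous_runtime_state)` in the diff above makes the restore point visible: the state is put back before the `ReturningFromWasm` hook runs.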

/// This function is called to register state within `Store` whenever
/// WebAssembly is entered within the `Store`.
///
/// This function sets up various limits such as:
///
/// * The stack limit. This is what ensures that we limit the stack space
/// allocated by WebAssembly code and it's relative to the initial stack
/// pointer that called into wasm.
///
/// This function may fail if the stack limit can't be set because an
/// interrupt already happened.
fn enter_wasm<T>(store: &mut StoreContextMut<'_, T>) -> Option<usize> {
// If this is a recursive call, e.g. our stack limit is already set, then
// we may be able to skip this function.
//
// For synchronous stores there's nothing else to do because all wasm calls
// happen synchronously and on the same stack. This means that the previous
// stack limit will suffice for the next recursive call.
//
// For asynchronous stores then each call happens on a separate native
// stack. This means that the previous stack limit is no longer relevant
// because we're on a separate stack.
if unsafe { *store.0.vm_store_context().stack_limit.get() } != usize::MAX
&& !store.0.async_support()
{
return None;
}

// Ignore this stack pointer business on miri since we can't execute wasm
// anyway and the concept of a stack pointer on miri is a bit nebulous
// regardless.
if cfg!(miri) {
return None;
}

// When Cranelift has support for the host then we might be running native
// compiled code meaning we need to read the actual stack pointer. If
// Cranelift can't be used though then we're guaranteed to be running pulley
// in which case this stack pointer isn't actually used as Pulley has custom
// mechanisms for stack overflow.
#[cfg(has_host_compiler_backend)]
let stack_pointer = crate::runtime::vm::get_stack_pointer();
#[cfg(not(has_host_compiler_backend))]
let stack_pointer = {
use wasmtime_environ::TripleExt;
debug_assert!(store.engine().target().is_pulley());
usize::MAX
};
/// This type helps manage the runtime state when entering and exiting Wasm. To
/// this end, it contains a subset of the data in `VMStoreContext`. Upon
/// entering Wasm, it updates various runtime fields and saves their original
/// values in this struct. Upon exiting Wasm, the previous values are restored.
pub(crate) struct EntryStoreContext {
/// If set, contains the value of the `stack_limit` field to restore in
/// the `VMStoreContext` when exiting Wasm.
pub stack_limit: Option<usize>,
/// Contains the value of the `last_wasm_exit_pc` field to restore in
/// the `VMStoreContext` when exiting Wasm.
pub last_wasm_exit_pc: usize,
/// Contains the value of the `last_wasm_exit_fp` field to restore in
/// the `VMStoreContext` when exiting Wasm.
pub last_wasm_exit_fp: usize,
/// Contains the value of the `last_wasm_entry_fp` field to restore in
/// the `VMStoreContext` when exiting Wasm.
pub last_wasm_entry_fp: usize,

/// We need a pointer to the `VMStoreContext` so that we can update its
/// fields from `drop`/`exit_wasm`.
vm_store_context: *const VMStoreContext,
}

// Determine the stack pointer beyond which any wasm code will
// immediately trap. This is checked on entry to all wasm functions.
//
// Note that this isn't 100% precise. We are requested to give wasm
// `max_wasm_stack` bytes, but what we're actually doing is giving wasm
// probably a little less than `max_wasm_stack` because we're
// calculating the limit relative to this function's approximate stack
// pointer. Wasm will be executed on a frame beneath this one (or next
// to it). In any case it's expected to be at most a few hundred bytes
// of slop one way or another. When wasm is typically given a MB or so
// (a million bytes) the slop shouldn't matter too much.
//
// After we've got the stack limit then we store it into the `stack_limit`
// variable.
let wasm_stack_limit = stack_pointer - store.engine().config().max_wasm_stack;
let prev_stack = unsafe {
mem::replace(
&mut *store.0.vm_store_context().stack_limit.get(),
wasm_stack_limit,
)
};
impl EntryStoreContext {
/// This function is called to update and save state when
/// WebAssembly is entered within the `Store`.
///
/// This updates various fields such as:
///
/// * The stack limit. This is what ensures that we limit the stack space
/// allocated by WebAssembly code and it's relative to the initial stack
/// pointer that called into wasm.
///
/// It also saves the various `last_wasm_*` values from the `VMStoreContext`.
pub fn enter_wasm<T>(store: &mut StoreContextMut<'_, T>) -> Self {
let stack_limit;

Some(prev_stack)
}
// If this is a recursive call, e.g. our stack limit is already set, then
// we may be able to skip this function.
//
// For synchronous stores there's nothing else to do because all wasm calls
// happen synchronously and on the same stack. This means that the previous
// stack limit will suffice for the next recursive call.
//
// For asynchronous stores then each call happens on a separate native
// stack. This means that the previous stack limit is no longer relevant
// because we're on a separate stack.
if unsafe { *store.0.vm_store_context().stack_limit.get() } != usize::MAX
&& !store.0.async_support()
{
stack_limit = None;
}
// Ignore this stack pointer business on miri since we can't execute wasm
// anyway and the concept of a stack pointer on miri is a bit nebulous
// regardless.
else if cfg!(miri) {
stack_limit = None;
} else {
// When Cranelift has support for the host then we might be running native
// compiled code meaning we need to read the actual stack pointer. If
// Cranelift can't be used though then we're guaranteed to be running pulley
// in which case this stack pointer isn't actually used as Pulley has custom
// mechanisms for stack overflow.
#[cfg(has_host_compiler_backend)]
let stack_pointer = crate::runtime::vm::get_stack_pointer();
#[cfg(not(has_host_compiler_backend))]
let stack_pointer = {
use wasmtime_environ::TripleExt;
debug_assert!(store.engine().target().is_pulley());
usize::MAX
};

fn exit_wasm<T>(store: &mut StoreContextMut<'_, T>, prev_stack: Option<usize>) {
// If we don't have a previous stack pointer to restore, then there's no
// cleanup we need to perform here.
let prev_stack = match prev_stack {
Some(stack) => stack,
None => return,
};
// Determine the stack pointer beyond which any wasm code will
// immediately trap. This is checked on entry to all wasm functions.
//
// Note that this isn't 100% precise. We are requested to give wasm
// `max_wasm_stack` bytes, but what we're actually doing is giving wasm
// probably a little less than `max_wasm_stack` because we're
// calculating the limit relative to this function's approximate stack
// pointer. Wasm will be executed on a frame beneath this one (or next
// to it). In any case it's expected to be at most a few hundred bytes
// of slop one way or another. When wasm is typically given a MB or so
// (a million bytes) the slop shouldn't matter too much.
//
// After we've got the stack limit then we store it into the `stack_limit`
// variable.
let wasm_stack_limit = stack_pointer
.checked_sub(store.engine().config().max_wasm_stack)
.unwrap();
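// For example, with a (hypothetical) stack pointer of 0x7fff_f000_0000
// and a `max_wasm_stack` of 512 KiB (0x8_0000), the limit stored below
// is 0x7fff_eff8_0000; any wasm frame that would push the stack pointer
// under that value traps instead. The `checked_sub` above turns an
// underflow into a panic rather than silently wrapping.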
let prev_stack = unsafe {
mem::replace(
&mut *store.0.vm_store_context().stack_limit.get(),
wasm_stack_limit,
)
};
stack_limit = Some(prev_stack);
}

unsafe {
*store.0.vm_store_context().stack_limit.get() = prev_stack;
unsafe {
let last_wasm_exit_pc = *store.0.vm_store_context().last_wasm_exit_pc.get();
let last_wasm_exit_fp = *store.0.vm_store_context().last_wasm_exit_fp.get();
let last_wasm_entry_fp = *store.0.vm_store_context().last_wasm_entry_fp.get();

let vm_store_context = store.0.vm_store_context();

Self {
stack_limit,
last_wasm_exit_pc,
last_wasm_exit_fp,
last_wasm_entry_fp,
vm_store_context,
}
}
}

/// This function restores the values stored in this struct. We invoke this
/// function through this type's `Drop` implementation. This ensures that the
/// values are restored even if we unwind the stack (e.g., because we are
/// panicking out of a Wasm execution).
fn exit_wasm(&mut self) {
unsafe {
if let Some(limit) = self.stack_limit {
*(&*self.vm_store_context).stack_limit.get() = limit;
}

*(*self.vm_store_context).last_wasm_exit_fp.get() = self.last_wasm_exit_fp;
*(*self.vm_store_context).last_wasm_exit_pc.get() = self.last_wasm_exit_pc;
*(*self.vm_store_context).last_wasm_entry_fp.get() = self.last_wasm_entry_fp;
}
}
}

impl Drop for EntryStoreContext {
fn drop(&mut self) {
self.exit_wasm();
}
}
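
Because the restore logic lives in `Drop`, it runs even when a panic unwinds out of a wasm execution, as the doc comment above notes. A small sketch demonstrating that property with toy types (`LAST_EXIT_PC` and `RestoreOnDrop` are hypothetical stand-ins, not wasmtime types):

```rust
use std::cell::Cell;
use std::panic::{catch_unwind, AssertUnwindSafe};

thread_local! {
    // Toy stand-in for a `VMStoreContext` field normally restored by
    // `exit_wasm`.
    static LAST_EXIT_PC: Cell<usize> = const { Cell::new(0) };
}

struct RestoreOnDrop {
    saved: usize,
}

impl Drop for RestoreOnDrop {
    fn drop(&mut self) {
        // Runs during unwinding too, so the old value always comes back.
        LAST_EXIT_PC.with(|pc| pc.set(self.saved));
    }
}

fn main() {
    LAST_EXIT_PC.with(|pc| pc.set(0xAAAA));
    let result = catch_unwind(AssertUnwindSafe(|| {
        let _guard = RestoreOnDrop { saved: LAST_EXIT_PC.with(|pc| pc.get()) };
        LAST_EXIT_PC.with(|pc| pc.set(0xBBBB)); // "enter wasm"
        panic!("simulated panic while executing wasm");
    }));
    assert!(result.is_err());
    // Despite the panic, the destructor restored the previous value.
    assert_eq!(LAST_EXIT_PC.with(|pc| pc.get()), 0xAAAA);
}
```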

60 changes: 24 additions & 36 deletions crates/wasmtime/src/runtime/vm/traphandlers.rs
Original file line number Diff line number Diff line change
@@ -15,11 +15,11 @@ mod signals;
#[cfg(all(has_native_signals))]
pub use self::signals::*;

use crate::prelude::*;
use crate::runtime::module::lookup_code;
use crate::runtime::store::{ExecutorRef, StoreOpaque};
use crate::runtime::vm::sys::traphandlers;
use crate::runtime::vm::{InterpreterRef, VMContext, VMStoreContext};
use crate::{prelude::*, EntryStoreContext};
use crate::{StoreContextMut, WasmBacktrace};
use core::cell::Cell;
use core::num::NonZeroU32;
@@ -365,14 +365,15 @@ impl From<wasmtime_environ::Trap> for TrapReason {
/// longjmp'd over and none of its destructors on the stack may be run.
pub unsafe fn catch_traps<T, F>(
store: &mut StoreContextMut<'_, T>,
old_state: &EntryStoreContext,
mut closure: F,
) -> Result<(), Box<Trap>>
where
F: FnMut(NonNull<VMContext>, Option<InterpreterRef<'_>>) -> bool,
{
let caller = store.0.default_caller();

let result = CallThreadState::new(store.0).with(|cx| match store.0.executor() {
let result = CallThreadState::new(store.0, old_state).with(|cx| match store.0.executor() {
// In interpreted mode directly invoke the host closure since we won't
// be using host-based `setjmp`/`longjmp` as that's not going to save
// the context we want.
@@ -424,6 +425,7 @@ where
mod call_thread_state {
use super::*;
use crate::runtime::vm::Unwind;
use crate::EntryStoreContext;

/// Temporary state stored on the stack which is registered in the `tls`
/// module below for calls into wasm.
@@ -462,39 +464,33 @@ mod call_thread_state {
#[cfg(all(has_native_signals, unix))]
pub(crate) async_guard_range: Range<*mut u8>,

// The values of `VMStoreContext::last_wasm_{exit_{pc,fp},entry_fp}` for
// the *previous* `CallThreadState` for this same store/limits. Our
// *current* last wasm PC/FP/SP are saved in `self.vm_store_context`. We
// save a copy of the old registers here because the `VMStoreContext`
// typically doesn't change across nested calls into Wasm (i.e. they are
// typically calls back into the same store and `self.vm_store_context
// == self.prev.vm_store_context`) and we must maintain the list of
// contiguous-Wasm-frames stack regions for backtracing purposes.
old_last_wasm_exit_fp: Cell<usize>,
old_last_wasm_exit_pc: Cell<usize>,
old_last_wasm_entry_fp: Cell<usize>,
// The state of the runtime for the *previous* `CallThreadState` for
// this same store. Our *current* state is saved in
// `self.vm_store_context`, etc. We need access to the old values of
// these fields because the `VMStoreContext` typically doesn't change
// across nested calls into Wasm (i.e. they are typically calls back into
// the same store and `self.vm_store_context == self.prev.vm_store_context`)
// and we must maintain the list of contiguous-Wasm-frames stack regions
// for backtracing purposes.
old_state: *const EntryStoreContext,
}

impl Drop for CallThreadState {
fn drop(&mut self) {
// Unwind information should not be present as it should have
// already been processed.
debug_assert!(self.unwind.replace(None).is_none());

unsafe {
let cx = self.vm_store_context.as_ref();
*cx.last_wasm_exit_fp.get() = self.old_last_wasm_exit_fp.get();
*cx.last_wasm_exit_pc.get() = self.old_last_wasm_exit_pc.get();
*cx.last_wasm_entry_fp.get() = self.old_last_wasm_entry_fp.get();
}
}
}

impl CallThreadState {
pub const JMP_BUF_INTERPRETER_SENTINEL: *mut u8 = 1 as *mut u8;

#[inline]
pub(super) fn new(store: &mut StoreOpaque) -> CallThreadState {
pub(super) fn new(
store: &mut StoreOpaque,
old_state: *const EntryStoreContext,
) -> CallThreadState {
// Don't try to plumb #[cfg] everywhere for this field, just pretend
// we're using it on miri/windows to silence compiler warnings.
let _: Range<_> = store.async_guard_range();
@@ -512,31 +508,23 @@ mod call_thread_state {
#[cfg(all(has_native_signals, unix))]
async_guard_range: store.async_guard_range(),
prev: Cell::new(ptr::null()),
old_last_wasm_exit_fp: Cell::new(unsafe {
*store.vm_store_context().last_wasm_exit_fp.get()
}),
old_last_wasm_exit_pc: Cell::new(unsafe {
*store.vm_store_context().last_wasm_exit_pc.get()
}),
old_last_wasm_entry_fp: Cell::new(unsafe {
*store.vm_store_context().last_wasm_entry_fp.get()
}),
old_state,
}
}

/// Get the saved FP upon exit from Wasm for the previous `CallThreadState`.
pub fn old_last_wasm_exit_fp(&self) -> usize {
self.old_last_wasm_exit_fp.get()
pub unsafe fn old_last_wasm_exit_fp(&self) -> usize {
(&*self.old_state).last_wasm_exit_fp
}

/// Get the saved PC upon exit from Wasm for the previous `CallThreadState`.
pub fn old_last_wasm_exit_pc(&self) -> usize {
self.old_last_wasm_exit_pc.get()
pub unsafe fn old_last_wasm_exit_pc(&self) -> usize {
(&*self.old_state).last_wasm_exit_pc
}

/// Get the saved FP upon entry into Wasm for the previous `CallThreadState`.
pub fn old_last_wasm_entry_fp(&self) -> usize {
self.old_last_wasm_entry_fp.get()
pub unsafe fn old_last_wasm_entry_fp(&self) -> usize {
(&*self.old_state).last_wasm_entry_fp
}

/// Get the previous `CallThreadState`.
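
Note that these accessors became `unsafe` in this PR: `old_state` is now a raw pointer to an `EntryStoreContext` that lives in the caller's stack frame (inside `invoke_wasm_and_catch_traps`), so callers must guarantee the guard is still alive when they read through it. A minimal sketch of that contract with toy types (not wasmtime's API):

```rust
// Sketch of the ownership contract behind `CallThreadState::old_state`:
// a non-owning pointer to caller-owned, stack-allocated saved state,
// read back through `unsafe` accessors.

struct SavedState {
    last_wasm_exit_fp: usize,
}

struct ThreadState {
    // Non-owning pointer; validity is guaranteed by the caller keeping
    // `SavedState` alive for as long as `ThreadState` exists.
    old_state: *const SavedState,
}

impl ThreadState {
    /// # Safety
    /// `old_state` must still point at a live `SavedState`.
    unsafe fn old_last_wasm_exit_fp(&self) -> usize {
        // SAFETY: the caller upholds the liveness requirement above.
        unsafe { (*self.old_state).last_wasm_exit_fp }
    }
}

fn main() {
    let saved = SavedState { last_wasm_exit_fp: 0x1234 };
    let ts = ThreadState { old_state: &saved };
    // Sound: `saved` outlives `ts` in this scope.
    assert_eq!(unsafe { ts.old_last_wasm_exit_fp() }, 0x1234);
}
```

Holding a pointer rather than copying the three values into fresh `Cell`s (as the removed code did) keeps the saved state in one place, so the backtrace machinery reads exactly the values that `EntryStoreContext::exit_wasm` will later restore.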