improve error handling #116

Merged · 16 commits · Feb 21, 2025
4 changes: 4 additions & 0 deletions src/backends/fallback.rs
@@ -0,0 +1,4 @@
#[inline(always)]
pub unsafe fn guess_os_stack_limit() -> Option<usize> {
    None
}
6 changes: 6 additions & 0 deletions src/backends/macos.rs
@@ -0,0 +1,6 @@
pub unsafe fn guess_os_stack_limit() -> Option<usize> {
    Some(
        libc::pthread_get_stackaddr_np(libc::pthread_self()) as usize
            - libc::pthread_get_stacksize_np(libc::pthread_self()) as usize,
    )
}
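Note: on macOS, pthread_get_stackaddr_np returns the highest address of the thread's stack and the stack grows downward, so subtracting the stack size yields the lowest usable address. A small illustration of that arithmetic follows; the addresses are invented, purely for illustration:

fn main() {
    let stackaddr: usize = 0x7000_0080_0000; // hypothetical top of the stack (highest address)
    let stacksize: usize = 0x80_0000; // hypothetical 8 MiB stack
    let limit = stackaddr - stacksize; // lowest usable address, i.e. the guessed limit
    assert_eq!(limit, 0x7000_0000_0000);
    println!("guessed stack limit: {:#x}", limit);
}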
28 changes: 28 additions & 0 deletions src/backends/mod.rs
@@ -0,0 +1,28 @@
cfg_if! {
    if #[cfg(miri)] {
        mod fallback;
        pub use fallback::guess_os_stack_limit;
    } else if #[cfg(windows)] {
        pub(crate) mod windows;
        pub use windows::guess_os_stack_limit;
    } else if #[cfg(any(
        target_os = "linux",
        target_os = "solaris",
        target_os = "netbsd",
        target_os = "freebsd",
        target_os = "dragonfly",
        target_os = "illumos"
    ))] {
        mod unix;
        pub use unix::guess_os_stack_limit;
    } else if #[cfg(target_os = "openbsd")] {
        mod openbsd;
        pub use openbsd::guess_os_stack_limit;
    } else if #[cfg(target_os = "macos")] {
        mod macos;
        pub use macos::guess_os_stack_limit;
    } else {
        mod fallback;
        pub use fallback::guess_os_stack_limit;
    }
}
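Note: the cfg_if dispatch above gives every platform the same guess_os_stack_limit() -> Option<usize> signature, with None meaning "unknown" rather than a panic or a bogus address. A minimal sketch of how a caller could consume it; the remaining_stack helper is hypothetical and not part of this PR, and psm::stack_pointer() is the same helper the Windows backend uses below:

// Hypothetical caller-side sketch, not part of this change.
unsafe fn remaining_stack() -> Option<usize> {
    // The guessed limit is the lowest usable stack address and the stack grows
    // downward, so the distance from the current stack pointer to the limit is
    // the space still available.
    let limit = crate::backends::guess_os_stack_limit()?;
    Some(psm::stack_pointer() as usize - limit)
}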
8 changes: 8 additions & 0 deletions src/backends/openbsd.rs
@@ -0,0 +1,8 @@
pub unsafe fn guess_os_stack_limit() -> Option<usize> {
    let mut stackinfo = std::mem::MaybeUninit::<libc::stack_t>::uninit();
    let res = libc::pthread_stackseg_np(libc::pthread_self(), stackinfo.as_mut_ptr());
    if res != 0 {
        return None;
    }
    Some(stackinfo.assume_init().ss_sp as usize - stackinfo.assume_init().ss_size)
}
58 changes: 58 additions & 0 deletions src/backends/unix.rs
@@ -0,0 +1,58 @@
#[cfg(any(target_os = "freebsd", target_os = "dragonfly", target_os = "illumos"))]
use libc::pthread_attr_get_np as get_attr;
#[cfg(any(target_os = "linux", target_os = "solaris", target_os = "netbsd"))]
use libc::pthread_getattr_np as get_attr;

pub unsafe fn guess_os_stack_limit() -> Option<usize> {
    let mut attr = PthreadAttr::new()?;

    handle_pthread_err(get_attr(libc::pthread_self(), attr.as_mut_ptr()))?;

    let mut stackaddr = std::ptr::null_mut();
    let mut stacksize = 0;
    handle_pthread_err(libc::pthread_attr_getstack(
        attr.as_mut_ptr(),
        &mut stackaddr,
        &mut stacksize,
    ))?;

    Some(stackaddr as usize)
}

struct PthreadAttr(std::mem::MaybeUninit<libc::pthread_attr_t>);

impl Drop for PthreadAttr {
    fn drop(&mut self) {
        unsafe {
            let ret = libc::pthread_attr_destroy(self.0.as_mut_ptr());
            if ret != 0 {
                let err = std::io::Error::last_os_error();
                panic!(
                    "pthread_attr_destroy failed with error code {}: {}",
                    ret, err
                );
            }
        }
    }
}

fn handle_pthread_err(ret: libc::c_int) -> Option<()> {
    if ret != 0 {
        return None;
    }
    Some(())
}

impl PthreadAttr {
    unsafe fn new() -> Option<Self> {
        let mut attr = std::mem::MaybeUninit::<libc::pthread_attr_t>::uninit();
        if libc::pthread_attr_init(attr.as_mut_ptr()) != 0 {
            return None;
        }
        Some(PthreadAttr(attr))
    }

    fn as_mut_ptr(&mut self) -> *mut libc::pthread_attr_t {
        self.0.as_mut_ptr()
    }
}
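Note: the error handling here funnels every nonzero pthread return code through handle_pthread_err and the ? operator into None, and the Drop impl on PthreadAttr ensures pthread_attr_destroy runs even on the early-return paths. A rough, illustrative sanity check of the result; this is not part of the PR:

// Illustrative only: when a limit is reported, it should lie at or below the
// address of a local variable, since the stack grows downward.
fn main() {
    let probe = 0u8;
    let sp = &probe as *const u8 as usize;
    match unsafe { guess_os_stack_limit() } {
        Some(limit) => {
            assert!(limit <= sp);
            println!("guessed limit {:#x}, current frame near {:#x}", limit, sp);
        }
        None => println!("stack limit unknown on this platform"),
    }
}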
142 changes: 142 additions & 0 deletions src/backends/windows.rs
@@ -0,0 +1,142 @@
use libc::c_void;
use std::io;
use std::ptr;
use windows_sys::Win32::Foundation::BOOL;
use windows_sys::Win32::System::Memory::VirtualQuery;
use windows_sys::Win32::System::Threading::{
    ConvertFiberToThread, ConvertThreadToFiber, CreateFiber, DeleteFiber, IsThreadAFiber,
    SetThreadStackGuarantee, SwitchToFiber,
};

// Make sure the libstacker.a (implemented in C) is linked.
// See https://github.com/rust-lang/rust/issues/65610
#[link(name = "stacker")]
extern "C" {
    fn __stacker_get_current_fiber() -> *mut c_void;
}

struct FiberInfo<F> {
    callback: std::mem::MaybeUninit<F>,
    panic: Option<Box<dyn std::any::Any + Send + 'static>>,
    parent_fiber: *mut c_void,
}

unsafe extern "system" fn fiber_proc<F: FnOnce()>(data: *mut c_void) {
    // This function is the entry point to our inner fiber, and as argument we get an
    // instance of `FiberInfo`. We will set up the "runtime" for the callback and
    // execute it.
    let data = &mut *(data as *mut FiberInfo<F>);
    let old_stack_limit = crate::get_stack_limit();
    crate::set_stack_limit(guess_os_stack_limit());
    let callback = data.callback.as_ptr();
    data.panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(callback.read())).err();

    // Restore to the previous fiber.
    crate::set_stack_limit(old_stack_limit);
    SwitchToFiber(data.parent_fiber);
}

pub fn _grow(stack_size: usize, callback: &mut dyn FnMut()) {
    // Fibers (or stackful coroutines) are the only official way to create new stacks on the
    // same thread on Windows. So in order to extend the stack we create a fiber and switch
    // to it so we can use its stack. After running `callback` within our fiber, we switch
    // back to the current stack and destroy the fiber and its associated stack.
    unsafe {
        let was_fiber = IsThreadAFiber() == 1 as BOOL;
        let mut data = FiberInfo {
            callback: std::mem::MaybeUninit::new(callback),
            panic: None,
            parent_fiber: {
                if was_fiber {
                    // Get a handle to the current fiber. We need to use a C implementation
                    // for this as GetCurrentFiber is a header-only function.
                    __stacker_get_current_fiber()
                } else {
                    // Convert the current thread to a fiber, so we are able to switch back
                    // to the current stack. Threads converted to fibers still act like
                    // regular threads, but they have associated fiber data. We later
                    // convert it back to a regular thread and free the fiber data.
                    ConvertThreadToFiber(ptr::null_mut())
                }
            },
        };

        if data.parent_fiber.is_null() {
            panic!(
                "unable to convert thread to fiber: {}",
                io::Error::last_os_error()
            );
        }

        let fiber = CreateFiber(
            stack_size as usize,
            Some(fiber_proc::<&mut dyn FnMut()>),
            &mut data as *mut FiberInfo<&mut dyn FnMut()> as *mut _,
        );
        if fiber.is_null() {
            panic!("unable to allocate fiber: {}", io::Error::last_os_error());
        }

        // Switch to the fiber we created. This changes stacks and starts executing
        // fiber_proc on it. fiber_proc will run `callback` and then switch back to run the
        // next statement.
        SwitchToFiber(fiber);
        DeleteFiber(fiber);

        // Clean up.
        if !was_fiber && ConvertFiberToThread() == 0 {
            // FIXME: Perhaps should not panic here?
            panic!(
                "unable to convert back to thread: {}",
                io::Error::last_os_error()
            );
        }

        if let Some(p) = data.panic {
            std::panic::resume_unwind(p);
        }
    }
}

#[inline(always)]
fn get_thread_stack_guarantee() -> Option<usize> {
    let min_guarantee = if cfg!(target_pointer_width = "32") {
        0x1000
    } else {
        0x2000
    };
    let mut stack_guarantee = 0;
    unsafe {
        // Read the current thread stack guarantee.
        // This is the stack reserved for stack overflow
        // exception handling.
        // This doesn't return the true value, so we need
        // some further logic to calculate the real stack
        // guarantee. This logic is what is used on x86-32 and
        // x86-64 Windows 10. Other versions and platforms may differ.
        let ret = SetThreadStackGuarantee(&mut stack_guarantee);
        if ret == 0 {
            return None;
        }
    };
    Some(std::cmp::max(stack_guarantee, min_guarantee) as usize + 0x1000)
}

#[inline(always)]
pub unsafe fn guess_os_stack_limit() -> Option<usize> {
    // Query the allocation which contains our stack pointer in order
    // to discover the size of the stack.
    //
    // FIXME: we could read the stack base from the TIB, specifically the 3rd element of it.
    type QueryT = windows_sys::Win32::System::Memory::MEMORY_BASIC_INFORMATION;
    let mut mi = std::mem::MaybeUninit::<QueryT>::uninit();
    let res = VirtualQuery(
        psm::stack_pointer() as *const _,
        mi.as_mut_ptr(),
        std::mem::size_of::<QueryT>() as usize,
    );
    if res == 0 {
        return None;
    }
    Some(mi.assume_init().AllocationBase as usize + get_thread_stack_guarantee()? + 0x1000)
}
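Note: the limit returned on Windows is composed of three parts: AllocationBase (the lowest address of the stack reservation found by VirtualQuery), the thread stack guarantee, and one extra 0x1000 page, so code using the limit stays clear of the pages Windows keeps for stack-overflow handling. A worked example of the arithmetic with made-up runtime values; AllocationBase and the reported guarantee vary per thread:

// Illustrative arithmetic only; the concrete values are invented.
fn main() {
    let allocation_base: usize = 0x0000_00e4_9a00_0000; // hypothetical VirtualQuery result
    let reported_guarantee: usize = 0; // SetThreadStackGuarantee often reports 0
    let min_guarantee: usize = 0x2000; // 64-bit default used above
    let guarantee = reported_guarantee.max(min_guarantee) + 0x1000; // as in get_thread_stack_guarantee
    let limit = allocation_base + guarantee + 0x1000; // as in guess_os_stack_limit
    assert_eq!(limit - allocation_base, 0x4000);
    println!("usable stack would start at {:#x}", limit);
}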