diff --git a/src/guess_os_stack_limit/fallback.rs b/src/backends/fallback.rs similarity index 100% rename from src/guess_os_stack_limit/fallback.rs rename to src/backends/fallback.rs diff --git a/src/guess_os_stack_limit/macos.rs b/src/backends/macos.rs similarity index 100% rename from src/guess_os_stack_limit/macos.rs rename to src/backends/macos.rs diff --git a/src/guess_os_stack_limit/mod.rs b/src/backends/mod.rs similarity index 96% rename from src/guess_os_stack_limit/mod.rs rename to src/backends/mod.rs index 7b14a34..2821cfa 100644 --- a/src/guess_os_stack_limit/mod.rs +++ b/src/backends/mod.rs @@ -3,7 +3,7 @@ cfg_if! { mod fallback; pub use fallback::guess_os_stack_limit; } else if #[cfg(windows)] { - mod windows; + pub(crate) mod windows; pub use windows::guess_os_stack_limit; } else if #[cfg(any( target_os = "linux", diff --git a/src/guess_os_stack_limit/openbsd.rs b/src/backends/openbsd.rs similarity index 100% rename from src/guess_os_stack_limit/openbsd.rs rename to src/backends/openbsd.rs diff --git a/src/guess_os_stack_limit/unix.rs b/src/backends/unix.rs similarity index 100% rename from src/guess_os_stack_limit/unix.rs rename to src/backends/unix.rs diff --git a/src/guess_os_stack_limit/unix_pthread_wrapper.rs b/src/backends/unix_pthread_wrapper.rs similarity index 100% rename from src/guess_os_stack_limit/unix_pthread_wrapper.rs rename to src/backends/unix_pthread_wrapper.rs diff --git a/src/backends/windows.rs b/src/backends/windows.rs new file mode 100644 index 0000000..401f6fc --- /dev/null +++ b/src/backends/windows.rs @@ -0,0 +1,136 @@ +use libc::c_void; +use std::io; +use std::ptr; +use windows_sys::Win32::Foundation::BOOL; +use windows_sys::Win32::System::Memory::VirtualQuery; +use windows_sys::Win32::System::Threading::{ + ConvertFiberToThread, ConvertThreadToFiber, CreateFiber, DeleteFiber, IsThreadAFiber, + SetThreadStackGuarantee, SwitchToFiber, +}; + +// Make sure the libstacker.a (implemented in C) is linked. +// See https://github.com/rust-lang/rust/issues/65610 +#[link(name = "stacker")] +extern "C" { + fn __stacker_get_current_fiber() -> *mut c_void; +} + +struct FiberInfo { + callback: std::mem::MaybeUninit, + panic: Option>, + parent_fiber: *mut c_void, +} + +unsafe extern "system" fn fiber_proc(data: *mut c_void) { + // This function is the entry point to our inner fiber, and as argument we get an + // instance of `FiberInfo`. We will set-up the "runtime" for the callback and execute + // it. + let data = &mut *(data as *mut FiberInfo); + let old_stack_limit = crate::get_stack_limit(); + crate::set_stack_limit(guess_os_stack_limit()); + let callback = data.callback.as_ptr(); + data.panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(callback.read())).err(); + + // Restore to the previous Fiber + crate::set_stack_limit(old_stack_limit); + SwitchToFiber(data.parent_fiber); +} + +pub fn _grow(stack_size: usize, callback: &mut dyn FnMut()) { + // Fibers (or stackful coroutines) is the only official way to create new stacks on the + // same thread on Windows. So in order to extend the stack we create fiber and switch + // to it so we can use it's stack. After running `callback` within our fiber, we switch + // back to the current stack and destroy the fiber and its associated stack. + unsafe { + let was_fiber = IsThreadAFiber() == 1 as BOOL; + let mut data = FiberInfo { + callback: std::mem::MaybeUninit::new(callback), + panic: None, + parent_fiber: { + if was_fiber { + // Get a handle to the current fiber. We need to use a C implementation + // for this as GetCurrentFiber is an header only function. + __stacker_get_current_fiber() + } else { + // Convert the current thread to a fiber, so we are able to switch back + // to the current stack. Threads coverted to fibers still act like + // regular threads, but they have associated fiber data. We later + // convert it back to a regular thread and free the fiber data. + ConvertThreadToFiber(ptr::null_mut()) + } + }, + }; + + if data.parent_fiber.is_null() { + panic!( + "unable to convert thread to fiber: {}", + io::Error::last_os_error() + ); + } + + let fiber = CreateFiber( + stack_size as usize, + Some(fiber_proc::<&mut dyn FnMut()>), + &mut data as *mut FiberInfo<&mut dyn FnMut()> as *mut _, + ); + if fiber.is_null() { + panic!("unable to allocate fiber: {}", io::Error::last_os_error()); + } + + // Switch to the fiber we created. This changes stacks and starts executing + // fiber_proc on it. fiber_proc will run `callback` and then switch back to run the + // next statement. + SwitchToFiber(fiber); + DeleteFiber(fiber); + + // Clean-up. + if !was_fiber && ConvertFiberToThread() == 0 { + // FIXME: Perhaps should not panic here? + panic!( + "unable to convert back to thread: {}", + io::Error::last_os_error() + ); + } + + if let Some(p) = data.panic { + std::panic::resume_unwind(p); + } + } +} + +#[inline(always)] +fn get_thread_stack_guarantee() -> usize { + let min_guarantee = if cfg!(target_pointer_width = "32") { + 0x1000 + } else { + 0x2000 + }; + let mut stack_guarantee = 0; + unsafe { + // Read the current thread stack guarantee + // This is the stack reserved for stack overflow + // exception handling. + // This doesn't return the true value so we need + // some further logic to calculate the real stack + // guarantee. This logic is what is used on x86-32 and + // x86-64 Windows 10. Other versions and platforms may differ + SetThreadStackGuarantee(&mut stack_guarantee) + }; + std::cmp::max(stack_guarantee, min_guarantee) as usize + 0x1000 +} + +#[inline(always)] +pub unsafe fn guess_os_stack_limit() -> Option { + // Query the allocation which contains our stack pointer in order + // to discover the size of the stack + // + // FIXME: we could read stack base from the TIB, specifically the 3rd element of it. + type QueryT = windows_sys::Win32::System::Memory::MEMORY_BASIC_INFORMATION; + let mut mi = std::mem::MaybeUninit::::uninit(); + VirtualQuery( + psm::stack_pointer() as *const _, + mi.as_mut_ptr(), + std::mem::size_of::() as usize, + ); + Some(mi.assume_init().AllocationBase as usize + get_thread_stack_guarantee() + 0x1000) +} diff --git a/src/guess_os_stack_limit/windows.rs b/src/guess_os_stack_limit/windows.rs deleted file mode 100644 index 51a05e5..0000000 --- a/src/guess_os_stack_limit/windows.rs +++ /dev/null @@ -1,30 +0,0 @@ -use libc::c_void; -use std::io; -use std::ptr; -use windows_sys::Win32::Foundation::BOOL; -use windows_sys::Win32::System::Memory::VirtualQuery; -use windows_sys::Win32::System::Threading::*; - -#[inline(always)] -fn get_thread_stack_guarantee() -> usize { - let min_guarantee = if cfg!(target_pointer_width = "32") { - 0x1000 - } else { - 0x2000 - }; - let mut stack_guarantee = 0; - unsafe { SetThreadStackGuarantee(&mut stack_guarantee) }; - std::cmp::max(stack_guarantee, min_guarantee) as usize + 0x1000 -} - -#[inline(always)] -pub unsafe fn guess_os_stack_limit() -> Option { - type QueryT = windows_sys::Win32::System::Memory::MEMORY_BASIC_INFORMATION; - let mut mi = std::mem::MaybeUninit::::uninit(); - VirtualQuery( - psm::stack_pointer() as *const _, - mi.as_mut_ptr(), - std::mem::size_of::() as usize, - ); - Some(mi.assume_init().AllocationBase as usize + get_thread_stack_guarantee() + 0x1000) -} diff --git a/src/lib.rs b/src/lib.rs index 0277883..3f98ce8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,7 @@ extern crate windows_sys; #[macro_use] extern crate psm; -mod guess_os_stack_limit; +mod backends; use std::cell::Cell; @@ -118,7 +118,7 @@ psm_stack_information!( thread_local! { static STACK_LIMIT: Cell> = Cell::new(unsafe { - guess_os_stack_limit::guess_os_stack_limit() + backends::guess_os_stack_limit() }) } @@ -276,5 +276,7 @@ psm_stack_manipulation! { let _ = stack_size; callback(); } + #[cfg(windows)] + use backends::windows::_grow; } }