diff --git a/mlua-sys/src/lua52/lua.rs b/mlua-sys/src/lua52/lua.rs index e5239cee..8de7cdee 100644 --- a/mlua-sys/src/lua52/lua.rs +++ b/mlua-sys/src/lua52/lua.rs @@ -272,6 +272,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_CFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Garbage-collection function and options // diff --git a/mlua-sys/src/lua53/lua.rs b/mlua-sys/src/lua53/lua.rs index 2729fdcd..c3d82e63 100644 --- a/mlua-sys/src/lua53/lua.rs +++ b/mlua-sys/src/lua53/lua.rs @@ -286,6 +286,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_KFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Garbage-collection function and options // diff --git a/mlua-sys/src/lua54/lua.rs b/mlua-sys/src/lua54/lua.rs index 15a30444..c74e1576 100644 --- a/mlua-sys/src/lua54/lua.rs +++ b/mlua-sys/src/lua54/lua.rs @@ -299,6 +299,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_KFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Warning-related functions // diff --git a/mlua-sys/src/luau/lua.rs b/mlua-sys/src/luau/lua.rs index 8a55eef1..d898534c 100644 --- a/mlua-sys/src/luau/lua.rs +++ b/mlua-sys/src/luau/lua.rs @@ -426,6 +426,11 @@ pub unsafe fn lua_pushcclosure(L: *mut lua_State, f: lua_CFunction, nup: c_int) lua_pushcclosurek(L, f, ptr::null(), nup, None) } +#[inline(always)] +pub unsafe fn lua_pushcclosurec(L: *mut lua_State, f: lua_CFunction, cont: lua_Continuation, nup: c_int) { + lua_pushcclosurek(L, f, ptr::null(), nup, Some(cont)) +} + #[inline(always)] pub unsafe fn lua_pushcclosured(L: *mut lua_State, f: lua_CFunction, debugname: *const c_char, nup: c_int) { lua_pushcclosurek(L, f, debugname, nup, None) diff --git a/src/error.rs b/src/error.rs index c20a39ef..2d369109 100644 --- a/src/error.rs +++ b/src/error.rs @@ -321,7 +321,7 @@ impl fmt::Display for Error { Error::WithContext { context, cause } => { writeln!(fmt, "{context}")?; write!(fmt, "{cause}") - } + }, } } } diff --git a/src/lib.rs b/src/lib.rs index e1589ce6..8c7668ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,7 +110,7 @@ pub use crate::state::{GCMode, Lua, LuaOptions, WeakLua}; pub use crate::stdlib::StdLib; pub use crate::string::{BorrowedBytes, BorrowedStr, String}; pub use crate::table::{Table, TablePairs, TableSequence}; -pub use crate::thread::{Thread, ThreadStatus}; +pub use crate::thread::{ContinuationStatus, Thread, ThreadStatus}; pub use crate::traits::{ FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut, ObjectLike, }; diff --git a/src/prelude.rs b/src/prelude.rs index a3a03201..c3b5d443 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -3,15 +3,15 @@ #[doc(no_inline)] pub use crate::{ AnyUserData as LuaAnyUserData, BorrowedBytes as LuaBorrowedBytes, BorrowedStr as LuaBorrowedStr, - Chunk as LuaChunk, Either as LuaEither, Error as LuaError, ErrorContext as LuaErrorContext, - ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, FromLua, FromLuaMulti, - Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, Integer as LuaInteger, - IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, LuaNativeFnMut, LuaOptions, - MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, - ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, - String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, - Thread as LuaThread, ThreadStatus as LuaThreadStatus, UserData as LuaUserData, - UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, + Chunk as LuaChunk, ContinuationStatus as LuaContinuationStatus, Either as LuaEither, Error as LuaError, + ErrorContext as LuaErrorContext, ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, + FromLua, FromLuaMulti, Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, + Integer as LuaInteger, IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, + LuaNativeFnMut, LuaOptions, MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, + Number as LuaNumber, ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, + StdLib as LuaStdLib, String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, + TableSequence as LuaTableSequence, Thread as LuaThread, ThreadStatus as LuaThreadStatus, + UserData as LuaUserData, UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, Variadic as LuaVariadic, VmState as LuaVmState, WeakLua, diff --git a/src/state.rs b/src/state.rs index dbe4c2b2..9b006f9c 100644 --- a/src/state.rs +++ b/src/state.rs @@ -18,6 +18,10 @@ use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; + +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::thread::ContinuationStatus; + use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex, @@ -1265,6 +1269,44 @@ impl Lua { })) } + /// Same as ``create_function`` but with an added continuation function. + /// + /// The values passed to the continuation will be the yielded arguments + /// from the function for the initial continuation call. If yielding from a + /// continuation, the yielded results will be returned to the ``Thread::resume`` caller. The + /// arguments passed in the next ``Thread::resume`` call will then be the arguments passed + /// to the yielding continuation upon resumption. + /// + /// Returning a value from a continuation without setting yield + /// arguments will then be returned as the final return value of the Lua function call. + /// Values returned in a function in which there is also yielding will be ignored + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + pub fn create_function_with_continuation( + &self, + func: F, + cont: FC, + ) -> Result + where + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + FC: Fn(&Lua, ContinuationStatus, AC) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + AC: FromLuaMulti, + R: IntoLuaMulti, + RC: IntoLuaMulti, + { + (self.lock()).create_callback_with_continuation( + Box::new(move |rawlua, nargs| unsafe { + let args = A::from_stack_args(nargs, 1, None, rawlua)?; + func(rawlua.lua(), args)?.push_into_stack_multi(rawlua) + }), + Box::new(move |rawlua, nargs, status| unsafe { + let args = AC::from_stack_args(nargs, 1, None, rawlua)?; + let status = ContinuationStatus::from_status(status); + cont(rawlua.lua(), status, args)?.push_into_stack_multi(rawlua) + }), + ) + } + /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. /// /// This is a version of [`Lua::create_function`] that accepts a `FnMut` argument. @@ -2080,6 +2122,42 @@ impl Lua { pub(crate) unsafe fn raw_lua(&self) -> &RawLua { &*self.raw.data_ptr() } + + /// Set the yield arguments. Note that Lua will not yield until you return from the function + /// + /// This method is mostly useful with continuations and Rust-Rust yields + /// due to the Rust/Lua boundary + /// + /// Example: + /// + /// ```rust + /// fn test() -> mlua::Result<()> { + /// let lua = mlua::Lua::new(); + /// let always_yield = lua.create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)))?; + /// + /// let thread = lua.create_thread(always_yield)?; + /// assert_eq!( + /// thread.resume::<(i32, String, f32)>(())?, + /// (42, String::from("69420"), 45.6) + /// ); + /// + /// Ok(()) + /// } + /// ``` + pub fn yield_with(&self, args: impl IntoLuaMulti) -> Result<()> { + let raw = self.lock(); + unsafe { + raw.extra.get().as_mut().unwrap_unchecked().yielded_values = Some(args.into_lua_multi(self)?); + } + Ok(()) + } + + /// Checks if Lua is be allowed to yield. + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] + #[inline] + pub fn is_yieldable(&self) -> bool { + self.lock().is_yieldable() + } } impl WeakLua { diff --git a/src/state/extra.rs b/src/state/extra.rs index 5ff74a33..95455f09 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -13,11 +13,13 @@ use crate::error::Result; use crate::state::RawLua; use crate::stdlib::StdLib; use crate::types::{AppData, ReentrantMutex, XRc}; + use crate::userdata::RawUserDataRegistry; use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, WrappedFailure}; #[cfg(any(feature = "luau", doc))] use crate::chunk::Compiler; +use crate::MultiValue; #[cfg(feature = "async")] use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; @@ -94,6 +96,9 @@ pub(crate) struct ExtraData { pub(super) compiler: Option, #[cfg(feature = "luau-jit")] pub(super) enable_jit: bool, + + // Values currently being yielded from Lua.yield() + pub(super) yielded_values: Option, } impl Drop for ExtraData { @@ -196,6 +201,7 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, + yielded_values: None, })); // Store it in the registry diff --git a/src/state/raw.rs b/src/state/raw.rs index b7de97f2..1f9bb08f 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -11,7 +11,9 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, ref_stack_pop}; +#[allow(unused_imports)] +use crate::state::util::callback_error_ext; +use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; @@ -21,6 +23,12 @@ use crate::types::{ AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; + +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::Continuation; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::ContinuationUpvalue; + use crate::userdata::{ init_userdata_metatable, AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, UserDataStorage, @@ -197,6 +205,8 @@ impl RawLua { init_internal_metatable::>>(state, None)?; init_internal_metatable::(state, None)?; init_internal_metatable::(state, None)?; + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + init_internal_metatable::(state, None)?; #[cfg(not(feature = "luau"))] init_internal_metatable::(state, None)?; #[cfg(feature = "async")] @@ -1183,15 +1193,21 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => func(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + false, + ) } let state = self.state(); @@ -1215,6 +1231,124 @@ impl RawLua { } } + // Creates a Function out of a Callback and a continuation containing a 'static Fn. + // + // In Luau, uses pushcclosurek + // + // In Lua 5.2/5.3/5.4/JIT, makes a normal function that then yields to the continuation via yieldk + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + pub(crate) fn create_callback_with_continuation( + &self, + func: Callback, + cont: Continuation, + ) -> Result { + #[cfg(feature = "luau")] + { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.0)(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, ContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + })?; + } else { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + } + + Ok(Function(self.pop_ref())) + } + } + + #[cfg(not(feature = "luau"))] + { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some((ref func, _)) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, ContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosure(state, call_callback, 1); + })?; + } else { + ffi::lua_pushcclosure(state, call_callback, 1); + } + + Ok(Function(self.pop_ref())) + } + } + } + #[cfg(feature = "async")] pub(crate) fn create_async_callback(&self, func: AsyncCallback) -> Result { // Ensure that the coroutine library is loaded @@ -1375,6 +1509,12 @@ impl RawLua { pub(crate) unsafe fn set_waker(&self, waker: NonNull) -> NonNull { mem::replace(&mut (*self.extra.get()).waker, waker) } + + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] + #[inline] + pub(crate) fn is_yieldable(&self) -> bool { + unsafe { ffi::lua_isyieldable(self.state()) != 0 } + } } // Uses 3 stack spaces diff --git a/src/state/util.rs b/src/state/util.rs index c3c79302..69edf40f 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,3 +1,5 @@ +use crate::IntoLuaMulti; +use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; @@ -5,7 +7,10 @@ use std::sync::Arc; use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; -use crate::util::{self, get_internal_metatable, WrappedFailure}; +use crate::util::{self, check_stack, get_internal_metatable, WrappedFailure}; + +#[cfg(all(not(feature = "lua51"), not(feature = "luajit"), not(feature = "luau")))] +use crate::{types::ContinuationUpvalue, util::get_userdata}; struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); @@ -22,6 +27,65 @@ impl Drop for StateGuard<'_> { } } +pub(crate) enum PreallocatedFailure { + New(*mut WrappedFailure), + Reserved, +} + +impl PreallocatedFailure { + unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { + if (*extra).wrapped_failure_top > 0 { + (*extra).wrapped_failure_top -= 1; + return PreallocatedFailure::Reserved; + } + + // We need to check stack for Luau in case when callback is called from interrupt + // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + // Place it to the beginning of the stack + let ud = WrappedFailure::new_userdata(state); + ffi::lua_insert(state, 1); + PreallocatedFailure::New(ud) + } + + #[cold] + unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { + let ref_thread = (*extra).ref_thread; + match *self { + PreallocatedFailure::New(ud) => { + ffi::lua_settop(state, 1); + ud + } + PreallocatedFailure::Reserved => { + let index = (*extra).wrapped_failure_pool.pop().unwrap(); + ffi::lua_settop(state, 0); + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + ffi::lua_xpush(ref_thread, state, index); + ffi::lua_pushnil(ref_thread); + ffi::lua_replace(ref_thread, index); + (*extra).ref_free.push(index); + ffi::lua_touserdata(state, -1) as *mut WrappedFailure + } + } + } + + unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { + let ref_thread = (*extra).ref_thread; + match self { + PreallocatedFailure::New(_) => { + ffi::lua_rotate(state, 1, -1); + ffi::lua_xmove(state, ref_thread, 1); + let index = ref_stack_pop(extra); + (*extra).wrapped_failure_pool.push(index); + (*extra).wrapped_failure_top += 1; + } + PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, + } + } +} + // An optimized version of `callback_error` that does not allocate `WrappedFailure` userdata // and instead reuses unused values from previous calls (or allocates new). pub(crate) unsafe fn callback_error_ext( @@ -39,64 +103,78 @@ where let nargs = ffi::lua_gettop(state); - enum PreallocatedFailure { - New(*mut WrappedFailure), - Reserved, - } - - impl PreallocatedFailure { - unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { - if (*extra).wrapped_failure_top > 0 { - (*extra).wrapped_failure_top -= 1; - return PreallocatedFailure::Reserved; - } + // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory + // to store a wrapped failure (error or panic) *before* we proceed. + let prealloc_failure = PreallocatedFailure::reserve(state, extra); - // We need to check stack for Luau in case when callback is called from interrupt - // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - // Place it to the beginning of the stack - let ud = WrappedFailure::new_userdata(state); - ffi::lua_insert(state, 1); - PreallocatedFailure::New(ud) + match catch_unwind(AssertUnwindSafe(|| { + let rawlua = (*extra).raw_lua(); + let _guard = StateGuard::new(rawlua, state); + f(extra, nargs) + })) { + Ok(Ok(r)) => { + // Return unused `WrappedFailure` to the pool + prealloc_failure.release(state, extra); + r } + Ok(Err(err)) => { + let wrapped_error = prealloc_failure.r#use(state, extra); - #[cold] - unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { - let ref_thread = (*extra).ref_thread; - match *self { - PreallocatedFailure::New(ud) => { - ffi::lua_settop(state, 1); - ud - } - PreallocatedFailure::Reserved => { - let index = (*extra).wrapped_failure_pool.pop().unwrap(); - ffi::lua_settop(state, 0); - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - ffi::lua_xpush(ref_thread, state, index); - ffi::lua_pushnil(ref_thread); - ffi::lua_replace(ref_thread, index); - (*extra).ref_free.push(index); - ffi::lua_touserdata(state, -1) as *mut WrappedFailure - } + if !wrap_error { + ptr::write(wrapped_error, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) } - } - unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { - let ref_thread = (*extra).ref_thread; - match self { - PreallocatedFailure::New(_) => { - ffi::lua_rotate(state, 1, -1); - ffi::lua_xmove(state, ref_thread, 1); - let index = ref_stack_pop(extra); - (*extra).wrapped_failure_pool.push(index); - (*extra).wrapped_failure_top += 1; - } - PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, - } + // Build `CallbackError` with traceback + let traceback = if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { + ffi::luaL_traceback(state, state, ptr::null(), 0); + let traceback = util::to_string(state, -1); + ffi::lua_pop(state, 1); + traceback + } else { + "".to_string() + }; + let cause = Arc::new(err); + ptr::write( + wrapped_error, + WrappedFailure::Error(Error::CallbackError { traceback, cause }), + ); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + + ffi::lua_error(state) + } + Err(p) => { + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Panic(Some(p))); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) } } +} + +/// An yieldable version of `callback_error_ext` +/// +/// Unlike ``callback_error_ext``, this method requires a c_int return +/// and not a generic R +pub(crate) unsafe fn callback_error_ext_yieldable( + state: *mut ffi::lua_State, + mut extra: *mut ExtraData, + wrap_error: bool, + f: F, + #[allow(unused_variables)] in_callback_with_continuation: bool, +) -> c_int +where + F: FnOnce(*mut ExtraData, c_int) -> Result, +{ + if extra.is_null() { + extra = ExtraData::get(state); + } + + let nargs = ffi::lua_gettop(state); // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory // to store a wrapped failure (error or panic) *before* we proceed. @@ -109,7 +187,128 @@ where })) { Ok(Ok(r)) => { // Return unused `WrappedFailure` to the pool + // + // In either case, we cannot use it in the yield case anyways due to the lua_pop call + // so drop it properly now while we can. prealloc_failure.release(state, extra); + + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + + if let Some(values) = values { + // A note on Luau + // + // When using the yieldable continuations fflag (and in future when the fflag gets removed and + // yieldable continuations) becomes default, we must either pop the top of the + // stack on the state we are resuming or somehow store the number of + // args on top of stack pre-yield and then subtract in the resume in order to get predictable + // behaviour here. See https://github.com/luau-lang/luau/issues/1867 for more information + // + // In this case, popping is easier and leads to less bugs/more ergonomic API. + + if raw.state() == state { + // Edge case: main thread is being yielded + // + // We need to pop/clear stack early, then push args + ffi::lua_pop(state, -1); + } + + match values.push_into_stack_multi(raw) { + Ok(nargs) => { + // If not main thread, then clear and xmove to target thread + if raw.state() != state { + // luau preserves the stack making yieldable continuations ugly and leaky + // + // Even outside of luau, clearing the stack is probably desirable + ffi::lua_pop(state, -1); + if let Err(err) = check_stack(state, nargs) { + // Make a *new* preallocated failure, and then do normal error + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state); + } + ffi::lua_xmove(raw.state(), state, nargs); + } + + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + { + // Yield to a continuation. Unlike luau, we need to do this manually and on the + // fly using a yieldk call + if in_callback_with_continuation { + // On Lua 5.2, status and ctx are not present, so use 0 as status for + // compatibility + #[cfg(feature = "lua52")] + unsafe extern "C-unwind" fn cont_callback( + state: *mut ffi::lua_State, + ) -> c_int { + let upvalue = + get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available + // (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, 0), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + // Lua 5.3/5.4 case + #[cfg(not(feature = "lua52"))] + unsafe extern "C-unwind" fn cont_callback( + state: *mut ffi::lua_State, + status: c_int, + _ctx: ffi::lua_KContext, + ) -> c_int { + let upvalue = + get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available + // (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + return ffi::lua_yieldc(state, nargs, cont_callback); + } + } + + return ffi::lua_yield(state, nargs); + } + Err(err) => { + // Make a *new* preallocated failure, and then do normal wrap_error + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state); + } + } + } + r } Ok(Err(err)) => { diff --git a/src/thread.rs b/src/thread.rs index 13d95532..c98b947b 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -26,6 +26,25 @@ use { }, }; +/// Continuation thread status. Can either be Ok, Yielded (rare, but can happen) or Error +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum ContinuationStatus { + Ok, + Yielded, + Error, +} + +impl ContinuationStatus { + #[allow(dead_code)] + pub(crate) fn from_status(status: c_int) -> Self { + match status { + ffi::LUA_YIELD => Self::Yielded, + ffi::LUA_OK => Self::Ok, + _ => Self::Error, + } + } +} + /// Status of a Lua thread (coroutine). #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum ThreadStatus { @@ -215,6 +234,7 @@ impl Thread { let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); #[cfg(feature = "luau")] let ret = ffi::lua_resumex(thread_state, state, nargs, &mut nresults as *mut c_int); + match ret { ffi::LUA_OK => Ok((ThreadStatusInner::Finished, nresults)), ffi::LUA_YIELD => Ok((ThreadStatusInner::Yielded(0), nresults)), diff --git a/src/types.rs b/src/types.rs index 2589ea6e..45806247 100644 --- a/src/types.rs +++ b/src/types.rs @@ -39,6 +39,11 @@ pub(crate) type Callback = Box Result + Send + #[cfg(not(feature = "send"))] pub(crate) type Callback = Box Result + 'static>; +#[cfg(all(feature = "send", not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type Continuation = Box Result + Send + 'static>; + +#[cfg(all(not(feature = "send"), not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type Continuation = Box Result + 'static>; pub(crate) type ScopedCallback<'s> = Box Result + 's>; @@ -48,6 +53,8 @@ pub(crate) struct Upvalue { } pub(crate) type CallbackUpvalue = Upvalue>; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type ContinuationUpvalue = Upvalue>; #[cfg(all(feature = "async", feature = "send"))] pub(crate) type AsyncCallback = diff --git a/src/util/mod.rs b/src/util/mod.rs index f5fbae52..5dd6b19c 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -75,6 +75,7 @@ impl Drop for StackGuard { unsafe { let top = ffi::lua_gettop(self.state); if top < self.top { + println!("top={}, self.top={}", top, self.top); mlua_panic!("{} too many stack values popped", self.top - top) } if top > self.top { diff --git a/src/util/types.rs b/src/util/types.rs index 8bc9d8b2..8627042f 100644 --- a/src/util/types.rs +++ b/src/util/types.rs @@ -3,6 +3,9 @@ use std::os::raw::c_void; use crate::types::{Callback, CallbackUpvalue}; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::ContinuationUpvalue; + #[cfg(feature = "async")] use crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}; @@ -34,6 +37,15 @@ impl TypeKey for CallbackUpvalue { } } +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +impl TypeKey for ContinuationUpvalue { + #[inline(always)] + fn type_key() -> *const c_void { + static CONTINUATION_UPVALUE_TYPE_KEY: u8 = 0; + &CONTINUATION_UPVALUE_TYPE_KEY as *const u8 as *const c_void + } +} + #[cfg(not(feature = "luau"))] impl TypeKey for crate::types::HookCallback { #[inline(always)] diff --git a/tests/thread.rs b/tests/thread.rs index 4cb6ab10..9c8c1b65 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -252,3 +252,268 @@ fn test_thread_resume_error() -> Result<()> { Ok(()) } + +#[test] +fn test_thread_yield_args() -> Result<()> { + let lua = Lua::new(); + let always_yield = lua.create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)))?; + + let thread = lua.create_thread(always_yield)?; + assert_eq!( + thread.resume::<(i32, String, f32)>(())?, + (42, String::from("69420"), 45.6) + ); + + Ok(()) +} + +#[test] +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +fn test_continuation() { + let lua = Lua::new(); + // No yielding continuation fflag test + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + // empty yield args test + let cont_func = lua + .create_function_with_continuation( + |lua, _: ()| lua.yield_with(()), + |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res - 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + assert!(v.is_empty()); + let v = th.resume::(v).expect("Failed to load continuation"); + assert_eq!(v, -1); + + // Yielding continuation test (only supported on luau) + #[cfg(feature = "luau")] + { + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + } + + let cont_func = lua + .create_function_with_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); + + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func) + .expect("Failed to call cont_func"), + 2 + ); + + // basic yield test before we go any further + let always_yield = lua + .create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6))) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); + + // Trigger the continuation + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + let always_yield = lua + .create_function_with_continuation( + |lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)), + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with((a + 1, 1)), + |lua, status, args: mlua::MultiValue| { + println!("Reached cont recursive/multiple: {:?}", args); + + if args.len() == 5 { + if cfg!(any(feature = "luau", feature = "lua52")) { + assert_eq!(status, mlua::ContinuationStatus::Ok); + } else { + assert_eq!(status, mlua::ContinuationStatus::Yielded); + } + return Ok(6_i32); + } + + lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 + Ok(1_i32) // this will be ignored + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); + + // test panics + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local ok, res = pcall(cont_func, 1) + assert(not ok) + return tostring(res) + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); +}