From c5f36c9412e6eadc70d36badecd41bbf013da5a8 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sun, 22 Oct 2023 00:51:15 +0100 Subject: [PATCH] Add package module to Luau Introduce module loaders Support loading binary modules --- Cargo.toml | 5 +- mlua-sys/Cargo.toml | 2 +- mlua-sys/build/main_inner.rs | 4 +- mlua-sys/src/luau/lua.rs | 12 +- src/lua.rs | 126 ++++++------- src/luau.rs | 295 +++++++++++++++++++++++++----- src/memory.rs | 53 +++--- src/util/mod.rs | 27 +-- tests/luau.rs | 6 +- tests/module/Cargo.toml | 1 + tests/module/loader/Cargo.toml | 1 + tests/module/loader/tests/load.rs | 2 +- 12 files changed, 359 insertions(+), 175 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2c9664af..d15f58b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ lua52 = ["ffi/lua52"] lua51 = ["ffi/lua51"] luajit = ["ffi/luajit"] luajit52 = ["luajit", "ffi/luajit52"] -luau = ["ffi/luau"] +luau = ["ffi/luau", "libloading"] luau-jit = ["luau", "ffi/luau-codegen"] luau-vector4 = ["luau", "ffi/luau-vector4"] vendored = ["ffi/vendored"] @@ -57,6 +57,9 @@ parking_lot = { version = "0.12", optional = true } ffi = { package = "mlua-sys", version = "0.3.2", path = "mlua-sys" } +[target.'cfg(unix)'.dependencies] +libloading = { version = "0.8", optional = true } + [dev-dependencies] rustyline = "12.0" criterion = { version = "0.5", features = ["async_tokio"] } diff --git a/mlua-sys/Cargo.toml b/mlua-sys/Cargo.toml index 8cfa28da..0da8fcca 100644 --- a/mlua-sys/Cargo.toml +++ b/mlua-sys/Cargo.toml @@ -40,4 +40,4 @@ cfg-if = "1.0" pkg-config = "0.3.17" lua-src = { version = ">= 546.0.0, < 546.1.0", optional = true } luajit-src = { version = ">= 210.5.0, < 210.6.0", optional = true } -luau0-src = { version = "0.7.0", optional = true } +luau0-src = { version = "0.7.7", optional = true } diff --git a/mlua-sys/build/main_inner.rs b/mlua-sys/build/main_inner.rs index 9b8b15d5..668f40b3 100644 --- a/mlua-sys/build/main_inner.rs +++ b/mlua-sys/build/main_inner.rs @@ -9,8 +9,8 @@ cfg_if::cfg_if! { } fn main() { - #[cfg(all(feature = "luau", feature = "module"))] - compile_error!("Luau does not support `module` mode"); + #[cfg(all(feature = "luau", feature = "module", windows))] + compile_error!("Luau does not support `module` mode on Windows"); #[cfg(all(feature = "module", feature = "vendored"))] compile_error!("`vendored` and `module` features are mutually exclusive"); diff --git a/mlua-sys/src/luau/lua.rs b/mlua-sys/src/luau/lua.rs index 51168c03..996f62ee 100644 --- a/mlua-sys/src/luau/lua.rs +++ b/mlua-sys/src/luau/lua.rs @@ -2,7 +2,7 @@ use std::marker::{PhantomData, PhantomPinned}; use std::os::raw::{c_char, c_double, c_float, c_int, c_uint, c_void}; -use std::ptr; +use std::{mem, ptr}; // Option for multiple returns in 'lua_pcall' and 'lua_call' pub const LUA_MULTRET: c_int = -1; @@ -278,6 +278,7 @@ extern "C-unwind" { pub fn lua_getuserdatadtor(L: *mut lua_State, tag: c_int) -> Option; pub fn lua_clonefunction(L: *mut lua_State, idx: c_int); pub fn lua_cleartable(L: *mut lua_State, idx: c_int); + pub fn lua_getallocf(L: *mut lua_State, ud: *mut *mut c_void) -> lua_Alloc; } // @@ -325,6 +326,15 @@ pub unsafe fn lua_newuserdata(L: *mut lua_State, sz: usize) -> *mut c_void { lua_newuserdatatagged(L, sz, 0) } +#[inline(always)] +pub unsafe fn lua_newuserdata_t(L: *mut lua_State) -> *mut T { + unsafe extern "C-unwind" fn destructor(ud: *mut c_void) { + ptr::drop_in_place(ud as *mut T); + } + + lua_newuserdatadtor(L, mem::size_of::(), destructor::) as *mut T +} + // TODO: lua_strlen #[inline(always)] diff --git a/src/lua.rs b/src/lua.rs index 11765cee..30129595 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -1,17 +1,15 @@ use std::any::TypeId; use std::cell::{RefCell, UnsafeCell}; use std::ffi::{CStr, CString}; -use std::fmt; use std::marker::PhantomData; use std::mem::MaybeUninit; use std::ops::Deref; use std::os::raw::{c_char, c_int, c_void}; use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe, Location}; -use std::ptr::NonNull; use std::result::Result as StdResult; use std::sync::atomic::{AtomicPtr, Ordering}; use std::sync::{Arc, Mutex}; -use std::{mem, ptr, str}; +use std::{fmt, mem, ptr, str}; use rustc_hash::FxHashMap; @@ -60,6 +58,7 @@ use { crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}, futures_util::future::{self, Future}, futures_util::task::{noop_waker_ref, Context, Poll, Waker}, + std::ptr::NonNull, }; #[cfg(feature = "serialize")] @@ -94,7 +93,6 @@ pub(crate) struct ExtraData { safe: bool, libs: StdLib, - mem_state: Option>, #[cfg(feature = "module")] skip_memory_check: bool, @@ -244,11 +242,14 @@ impl Drop for Lua { impl Drop for LuaInner { fn drop(&mut self) { unsafe { - #[cfg(feature = "luau")] - { - (*ffi::lua_callbacks(self.state())).userdata = ptr::null_mut(); - } + let mem_state = MemoryState::get(self.main_state); + ffi::lua_close(self.main_state); + + // Deallocate MemoryState if it was created by us + if !mem_state.is_null() { + drop(Box::from_raw(mem_state)); + } } } } @@ -261,9 +262,6 @@ impl Drop for ExtraData { } *mlua_expect!(self.registry_unref_list.lock(), "unref list poisoned") = None; - if let Some(mem_state) = self.mem_state { - drop(unsafe { Box::from_raw(mem_state.as_ptr()) }); - } } } @@ -359,23 +357,22 @@ impl Lua { /// /// [`StdLib`]: crate::StdLib pub unsafe fn unsafe_new_with(libs: StdLib, options: LuaOptions) -> Lua { + // Workaround to avoid stripping a few unused Lua symbols that could be imported + // by C modules in unsafe mode + let mut _symbols: Vec<*const extern "C-unwind" fn()> = + vec![ffi::lua_isuserdata as _, ffi::lua_tocfunction as _]; + #[cfg(not(feature = "luau"))] + _symbols.extend_from_slice(&[ + ffi::lua_atpanic as _, + ffi::luaL_loadstring as _, + ffi::luaL_openlibs as _, + ]); + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] { - // Workaround to avoid stripping a few unused Lua symbols that could be imported - // by C modules in unsafe mode - let mut _symbols: Vec<*const extern "C-unwind" fn()> = vec![ - ffi::lua_atpanic as _, - ffi::lua_isuserdata as _, - ffi::lua_tocfunction as _, - ffi::luaL_loadstring as _, - ffi::luaL_openlibs as _, - ]; - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - { - _symbols.push(ffi::lua_getglobal as _); - _symbols.push(ffi::lua_setglobal as _); - _symbols.push(ffi::luaL_setfuncs as _); - } + _symbols.push(ffi::lua_getglobal as _); + _symbols.push(ffi::lua_setglobal as _); + _symbols.push(ffi::luaL_setfuncs as _); } Self::inner_new(libs, options) @@ -383,12 +380,11 @@ impl Lua { /// Creates a new Lua state with required `libs` and `options` unsafe fn inner_new(libs: StdLib, options: LuaOptions) -> Lua { - let mut mem_state: *mut MemoryState = Box::into_raw(Box::default()); + let mem_state: *mut MemoryState = Box::into_raw(Box::default()); let mut state = ffi::lua_newstate(ALLOCATOR, mem_state as *mut c_void); // If state is null then switch to Lua internal allocator if state.is_null() { drop(Box::from_raw(mem_state)); - mem_state = ptr::null_mut(); state = ffi::luaL_newstate(); } assert!(!state.is_null(), "Failed to instantiate Lua VM"); @@ -404,7 +400,6 @@ impl Lua { let lua = Lua::init_from_ptr(state); let extra = lua.extra.get(); - (*extra).mem_state = NonNull::new(mem_state); mlua_expect!( load_from_std_lib(state, libs), @@ -440,7 +435,7 @@ impl Lua { } #[cfg(feature = "luau")] - mlua_expect!(lua.prepare_luau_state(), "Error preparing Luau state"); + mlua_expect!(lua.prepare_luau_state(), "Error configuring Luau"); lua } @@ -514,7 +509,6 @@ impl Lua { app_data: AppData::default(), safe: false, libs: StdLib::NONE, - mem_state: None, #[cfg(feature = "module")] skip_memory_check: false, ref_thread, @@ -547,14 +541,8 @@ impl Lua { // Store it in the registry mlua_expect!( - (|state| { - push_gc_userdata(state, Arc::clone(&extra), true)?; - protect_lua!(state, 1, 0, fn(state) { - let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; - ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, extra_key); - }) - })(main_state), - "Error while storing extra data", + set_extra_data(main_state, &extra), + "Error while storing extra data" ); // Register `DestructedUserdata` type @@ -572,13 +560,6 @@ impl Lua { ); assert_stack(main_state, ffi::LUA_MINSTACK); - // Set Luau callbacks userdata to extra data - // We can use global callbacks userdata since we don't allow C modules in Luau - #[cfg(feature = "luau")] - { - (*ffi::lua_callbacks(main_state)).userdata = extra.get() as *mut c_void; - } - let inner = Arc::new(LuaInner { state: AtomicPtr::new(state), main_state, @@ -1098,9 +1079,9 @@ impl Lua { /// Returns the amount of memory (in bytes) currently used inside this Lua state. pub fn used_memory(&self) -> usize { unsafe { - match (*self.extra.get()).mem_state.map(|x| x.as_ref()) { - Some(mem_state) => mem_state.used_memory(), - None => { + match MemoryState::get(self.main_state) { + mem_state if !mem_state.is_null() => (*mem_state).used_memory(), + _ => { // Get data from the Lua GC let used_kbytes = ffi::lua_gc(self.main_state, ffi::LUA_GCCOUNT, 0); let used_kbytes_rem = ffi::lua_gc(self.main_state, ffi::LUA_GCCOUNTB, 0); @@ -1119,9 +1100,9 @@ impl Lua { /// Does not work in module mode where Lua state is managed externally. pub fn set_memory_limit(&self, limit: usize) -> Result { unsafe { - match (*self.extra.get()).mem_state.map(|mut x| x.as_mut()) { - Some(mem_state) => Ok(mem_state.set_memory_limit(limit)), - None => Err(Error::MemoryLimitNotAvailable), + match MemoryState::get(self.main_state) { + mem_state if !mem_state.is_null() => Ok((*mem_state).set_memory_limit(limit)), + _ => Err(Error::MemoryLimitNotAvailable), } } } @@ -3169,9 +3150,9 @@ impl Lua { #[inline] pub(crate) unsafe fn unlikely_memory_error(&self) -> bool { // MemoryInfo is empty in module mode so we cannot predict memory limits - (*self.extra.get()) - .mem_state - .map(|x| x.as_ref().memory_limit() == 0) + MemoryState::get(self.main_state) + .as_ref() + .map(|x| x.memory_limit() == 0) .unwrap_or_else(|| { // Alternatively, check the special flag (only for module mode) #[cfg(feature = "module")] @@ -3223,14 +3204,6 @@ impl LuaInner { } } -impl ExtraData { - #[cfg(feature = "luau")] - #[inline] - pub(crate) fn mem_state(&self) -> NonNull { - self.mem_state.unwrap() - } -} - struct StateGuard<'a>(&'a LuaInner, *mut ffi::lua_State); impl<'a> StateGuard<'a> { @@ -3246,13 +3219,13 @@ impl<'a> Drop for StateGuard<'a> { } } -#[cfg(feature = "luau")] unsafe fn extra_data(state: *mut ffi::lua_State) -> *mut ExtraData { - (*ffi::lua_callbacks(state)).userdata as *mut ExtraData -} + #[cfg(feature = "luau")] + if cfg!(not(feature = "module")) { + // In the main app we can use `lua_callbacks` to access ExtraData + return (*ffi::lua_callbacks(state)).userdata as *mut _; + } -#[cfg(not(feature = "luau"))] -unsafe fn extra_data(state: *mut ffi::lua_State) -> *mut ExtraData { let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA { // `ExtraData` can be null only when Lua state is foreign. @@ -3265,6 +3238,23 @@ unsafe fn extra_data(state: *mut ffi::lua_State) -> *mut ExtraData { (*extra_ptr).get() } +unsafe fn set_extra_data( + state: *mut ffi::lua_State, + extra: &Arc>, +) -> Result<()> { + #[cfg(feature = "luau")] + if cfg!(not(feature = "module")) { + (*ffi::lua_callbacks(state)).userdata = extra.get() as *mut _; + return Ok(()); + } + + push_gc_userdata(state, Arc::clone(extra), true)?; + protect_lua!(state, 1, 0, fn(state) { + let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; + ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, extra_key); + }) +} + // Creates required entries in the metatable cache (see `util::METATABLE_CACHE`) pub(crate) fn init_metatable_cache(cache: &mut FxHashMap) { cache.insert(TypeId::of::>>(), 0); diff --git a/src/luau.rs b/src/luau.rs index e17296e9..6f698e79 100644 --- a/src/luau.rs +++ b/src/luau.rs @@ -1,15 +1,54 @@ use std::ffi::CStr; +use std::fmt::Write; use std::os::raw::{c_float, c_int}; +use std::path::{PathBuf, MAIN_SEPARATOR_STR}; use std::string::String as StdString; +use std::{env, fs}; + +use rustc_hash::FxHashMap; use crate::chunk::ChunkMode; -use crate::error::{Error, Result}; +use crate::error::Result; use crate::lua::Lua; use crate::table::Table; -use crate::util::{check_stack, StackGuard}; -use crate::value::Value; +use crate::types::RegistryKey; +use crate::value::{IntoLua, Value}; + +#[cfg(unix)] +use libloading::Library; + +// Since Luau has some missing standard functions, we re-implement them here + +#[cfg(unix)] +const TARGET_MLUA_LUAU_ABI_VERSION: u32 = 2; + +#[cfg(all(unix, feature = "module"))] +#[no_mangle] +#[used] +pub static MLUA_LUAU_ABI_VERSION: u32 = TARGET_MLUA_LUAU_ABI_VERSION; + +// We keep reference to the `package` table in registry under this key +struct PackageKey(RegistryKey); + +// We keep reference to the loaded dylibs in application data +#[cfg(unix)] +struct LoadedDylibs(FxHashMap); -// Since Luau has some missing standard function, we re-implement them here +#[cfg(unix)] +impl std::ops::Deref for LoadedDylibs { + type Target = FxHashMap; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(unix)] +impl std::ops::DerefMut for LoadedDylibs { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} impl Lua { pub(crate) unsafe fn prepare_luau_state(&self) -> Result<()> { @@ -19,7 +58,8 @@ impl Lua { "collectgarbage", self.create_c_function(lua_collectgarbage)?, )?; - globals.raw_set("require", self.create_function(lua_require)?)?; + globals.raw_set("require", self.create_c_function(lua_require)?)?; + globals.raw_set("package", create_package_table(self)?)?; globals.raw_set("vector", self.create_c_function(lua_vector)?)?; // Set `_VERSION` global to include version number @@ -69,56 +109,56 @@ unsafe extern "C-unwind" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_ } } -fn lua_require(lua: &Lua, name: Option) -> Result { - let name = name.ok_or_else(|| Error::runtime("invalid module name"))?; - - // Find module in the cache - let state = lua.state(); - let loaded = unsafe { - let _sg = StackGuard::new(state); - check_stack(state, 2)?; - protect_lua!(state, 0, 1, fn(state) { - ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); - })?; - Table(lua.pop_ref()) - }; - if let Some(v) = loaded.raw_get(name.clone())? { - return Ok(v); +unsafe extern "C-unwind" fn lua_require(state: *mut ffi::lua_State) -> c_int { + ffi::lua_settop(state, 1); + let name = ffi::luaL_checkstring(state, 1); + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); // _LOADED is at index 2 + if ffi::lua_rawgetfield(state, 2, name) != ffi::LUA_TNIL { + return 1; // module is already loaded } + ffi::lua_pop(state, 1); // remove nil - // Load file from filesystem - let mut search_path = std::env::var("LUAU_PATH").unwrap_or_default(); - if search_path.is_empty() { - search_path = "?.luau;?.lua".into(); + // load the module + let err_buf = ffi::lua_newuserdata_t::(state); + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADERS")); // _LOADERS is at index 3 + for i in 1.. { + if ffi::lua_rawgeti(state, -1, i) == ffi::LUA_TNIL { + // no more loaders? + if (*err_buf).is_empty() { + ffi::luaL_error(state, cstr!("module '%s' not found"), name); + } else { + let bytes = (*err_buf).as_bytes(); + let extra = ffi::lua_pushlstring(state, bytes.as_ptr() as *const _, bytes.len()); + ffi::luaL_error(state, cstr!("module '%s' not found:%s"), name, extra); + } + } + ffi::lua_pushvalue(state, 1); // name arg + ffi::lua_call(state, 1, 2); // call loader + match ffi::lua_type(state, -2) { + ffi::LUA_TFUNCTION => break, // loader found + ffi::LUA_TSTRING => { + // error message + let msg = ffi::lua_tostring(state, -2); + let msg = CStr::from_ptr(msg).to_string_lossy(); + _ = write!(&mut *err_buf, "\n\t{msg}"); + } + _ => {} + } + ffi::lua_pop(state, 2); // remove both results } + ffi::lua_pushvalue(state, 1); // name is 1st argument to module loader + ffi::lua_rotate(state, -2, 1); // loader data <-> name - let (mut source, mut source_name) = (None, String::new()); - for path in search_path.split(';') { - let file_path = path.replacen('?', &name, 1); - if let Ok(buf) = std::fs::read(&file_path) { - source = Some(buf); - source_name = file_path; - break; - } + // stack: ...; loader function; module name; loader data + ffi::lua_call(state, 2, 1); + // stack: ...; result from loader function + if ffi::lua_isnil(state, -1) != 0 { + ffi::lua_pop(state, 1); + ffi::lua_pushboolean(state, 1); // use true as result } - let source = source.ok_or_else(|| Error::runtime(format!("cannot find '{name}'")))?; - - let value = lua - .load(&source) - .set_name(&format!("={source_name}")) - .set_mode(ChunkMode::Text) - .call::<_, Value>(())?; - - // Save in the cache - loaded.raw_set( - name, - match value.clone() { - Value::Nil => Value::Boolean(true), - v => v, - }, - )?; - - Ok(value) + ffi::lua_pushvalue(state, -1); // make copy of entrypoint result + ffi::lua_setfield(state, 2, name); /* _LOADED[name] = returned value */ + 1 } // Luau vector datatype constructor @@ -135,3 +175,158 @@ unsafe extern "C-unwind" fn lua_vector(state: *mut ffi::lua_State) -> c_int { ffi::lua_pushvector(state, x, y, z, w); 1 } + +// +// Module loaders +// + +/// Tries to load a lua (text) file +fn lua_loader(lua: &Lua, modname: StdString) -> Result { + let package = { + let key = lua.app_data_ref::().unwrap(); + lua.registry_value::(&key.0) + }?; + let search_path = package.get::<_, StdString>("path").unwrap_or_default(); + + if let Some(file_path) = package_searchpath(&modname, &search_path, false) { + match fs::read(&file_path) { + Ok(buf) => { + return lua + .load(&buf) + .set_name(&format!("={}", file_path.display())) + .set_mode(ChunkMode::Text) + .into_function() + .map(Value::Function); + } + Err(err) => { + return format!("cannot open '{}': {err}", file_path.display()).into_lua(lua); + } + } + } + + Ok(Value::Nil) +} + +/// Tries to load a dynamic library +#[cfg(unix)] +fn dylib_loader(lua: &Lua, modname: StdString) -> Result { + let package = { + let key = lua.app_data_ref::().unwrap(); + lua.registry_value::
(&key.0) + }?; + let search_cpath = package.get::<_, StdString>("cpath").unwrap_or_default(); + + let find_symbol = |lib: &Library| unsafe { + if let Ok(entry) = lib.get::(format!("luaopen_{modname}\0").as_bytes()) + { + return lua.create_c_function(*entry).map(Value::Function); + } + // Try all in one mode + if let Ok(entry) = lib.get::( + format!("luaopen_{}\0", modname.replace('.', "_")).as_bytes(), + ) { + return lua.create_c_function(*entry).map(Value::Function); + } + "cannot find module entrypoint".into_lua(lua) + }; + + if let Some(file_path) = package_searchpath(&modname, &search_cpath, true) { + let file_path = file_path.canonicalize()?; + // Load the library and check for symbol + unsafe { + // Check if it's already loaded + if let Some(lib) = lua.app_data_ref::().unwrap().get(&file_path) { + return find_symbol(lib); + } + if let Ok(lib) = Library::new(&file_path) { + // Check version + let mod_version = lib.get::<*const u32>(b"MLUA_LUAU_ABI_VERSION"); + let mod_version = mod_version.map(|v| **v).unwrap_or_default(); + if mod_version != TARGET_MLUA_LUAU_ABI_VERSION { + let err = format!("wrong module ABI version (expected {TARGET_MLUA_LUAU_ABI_VERSION}, got {mod_version})"); + return err.into_lua(lua); + } + let symbol = find_symbol(&lib); + lua.app_data_mut::() + .unwrap() + .insert(file_path, lib); + return symbol; + } + } + } + + Ok(Value::Nil) +} + +// +// package module +// + +fn create_package_table(lua: &Lua) -> Result
{ + // Create the package table and store it in app_data for later use (bypassing globals lookup) + let package = lua.create_table()?; + lua.set_app_data(PackageKey(lua.create_registry_value(package.clone())?)); + + // set package.path + let mut search_path = env::var("LUAU_PATH") + .or_else(|_| env::var("LUA_PATH")) + .unwrap_or_default(); + if search_path.is_empty() { + search_path = "?.luau;?.lua".to_string(); + } + package.raw_set("path", search_path)?; + + // set package.cpath + #[cfg(unix)] + { + let mut search_cpath = env::var("LUAU_CPATH") + .or_else(|_| env::var("LUA_CPATH")) + .unwrap_or_default(); + if search_cpath.is_empty() { + if cfg!(any(target_os = "macos", target_os = "ios")) { + search_cpath = "?.dylib".to_string(); + } else { + search_cpath = "?.so".to_string(); + } + } + package.raw_set("cpath", search_cpath)?; + } + + // set package.loaded (table with a list of loaded modules) + let loaded = lua.create_table()?; + package.raw_set("loaded", loaded.clone())?; + lua.set_named_registry_value("_LOADED", loaded)?; + + // set package.loaders + let loaders = lua.create_sequence_from([lua.create_function(lua_loader)?])?; + package.raw_set("loaders", loaders.clone())?; + #[cfg(unix)] + { + loaders.push(lua.create_function(dylib_loader)?)?; + let loaded_dylibs = LoadedDylibs(FxHashMap::default()); + lua.set_app_data(loaded_dylibs); + } + lua.set_named_registry_value("_LOADERS", loaders)?; + + Ok(package) +} + +/// Searches for the given `name`` in the given `path`. +/// +/// `path` is a string containing a sequence of templates separated by semicolons. +fn package_searchpath(name: &str, search_path: &str, try_prefix: bool) -> Option { + let mut names = vec![name.replace('.', MAIN_SEPARATOR_STR)]; + if try_prefix && name.contains('.') { + let prefix = name.split_once('.').map(|(prefix, _)| prefix).unwrap(); + names.push(prefix.to_string()); + } + for path in search_path.split(';') { + for name in &names { + let file_path = PathBuf::from(path.replace('?', name)); + if let Ok(true) = fs::metadata(&file_path).map(|m| m.is_file()) { + return Some(file_path); + } + } + } + None +} diff --git a/src/memory.rs b/src/memory.rs index e199a759..672a8647 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -2,11 +2,9 @@ use std::alloc::{self, Layout}; use std::os::raw::c_void; use std::ptr; -#[cfg(feature = "luau")] -use crate::lua::ExtraData; - pub(crate) static ALLOCATOR: ffi::lua_Alloc = allocator; +#[repr(C)] #[derive(Default)] pub(crate) struct MemoryState { used_memory: isize, @@ -20,6 +18,21 @@ pub(crate) struct MemoryState { } impl MemoryState { + #[inline] + pub(crate) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self { + let mut mem_state = ptr::null_mut(); + #[cfg(feature = "luau")] + { + ffi::lua_getallocf(state, &mut mem_state); + mlua_assert!(!mem_state.is_null(), "Luau state has no allocator userdata"); + } + #[cfg(not(feature = "luau"))] + if ffi::lua_getallocf(state, &mut mem_state) != ALLOCATOR { + mem_state = ptr::null_mut(); + } + mem_state as *mut MemoryState + } + #[inline] pub(crate) fn used_memory(&self) -> usize { self.used_memory as usize @@ -37,36 +50,21 @@ impl MemoryState { prev_limit as usize } - // This function is used primarily for calling `lua_pushcfunction` in lua5.1/jit + // This function is used primarily for calling `lua_pushcfunction` in lua5.1/jit/luau // to bypass the memory limit (if set). - #[cfg(any(feature = "lua51", feature = "luajit"))] + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] #[inline] pub(crate) unsafe fn relax_limit_with(state: *mut ffi::lua_State, f: impl FnOnce()) { - let mut mem_state: *mut c_void = ptr::null_mut(); - if ffi::lua_getallocf(state, &mut mem_state) == ALLOCATOR { - (*(mem_state as *mut MemoryState)).ignore_limit = true; + let mem_state = Self::get(state); + if !mem_state.is_null() { + (*mem_state).ignore_limit = true; f(); - (*(mem_state as *mut MemoryState)).ignore_limit = false; + (*mem_state).ignore_limit = false; } else { f(); } } - // Same as the above but for Luau - // It does not have `lua_getallocf` function, so instead we use `lua_callbacks` - #[cfg(feature = "luau")] - #[inline] - pub(crate) unsafe fn relax_limit_with(state: *mut ffi::lua_State, f: impl FnOnce()) { - let extra = (*ffi::lua_callbacks(state)).userdata as *mut ExtraData; - if extra.is_null() { - return f(); - } - let mem_state = (*extra).mem_state(); - (*mem_state.as_ptr()).ignore_limit = true; - f(); - (*mem_state.as_ptr()).ignore_limit = false; - } - // Does nothing apart from calling `f()`, we don't need to bypass any limits #[cfg(any(feature = "lua52", feature = "lua53", feature = "lua54"))] #[inline] @@ -76,12 +74,9 @@ impl MemoryState { // Returns `true` if the memory limit was reached on the last memory operation #[cfg(feature = "luau")] + #[inline] pub(crate) unsafe fn limit_reached(state: *mut ffi::lua_State) -> bool { - let extra = (*ffi::lua_callbacks(state)).userdata as *mut ExtraData; - if extra.is_null() { - return false; - } - (*(*extra).mem_state().as_ptr()).limit_reached + (*Self::get(state)).limit_reached } } diff --git a/src/util/mod.rs b/src/util/mod.rs index 1351fbfc..b196e684 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -2,11 +2,11 @@ use std::any::{Any, TypeId}; use std::borrow::Cow; use std::ffi::CStr; use std::fmt::Write; -use std::mem::MaybeUninit; +use std::mem::{self, MaybeUninit}; use std::os::raw::{c_char, c_int, c_void}; use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; use std::sync::Arc; -use std::{mem, ptr, slice, str}; +use std::{ptr, slice, str}; use once_cell::sync::Lazy; use rustc_hash::FxHashMap; @@ -300,20 +300,13 @@ pub unsafe fn push_userdata(state: *mut ffi::lua_State, t: T, protect: bool) #[cfg(feature = "luau")] #[inline] pub unsafe fn push_userdata(state: *mut ffi::lua_State, t: T, protect: bool) -> Result<()> { - unsafe extern "C-unwind" fn destructor(ud: *mut c_void) { - ptr::drop_in_place(ud as *mut T); - } - - let size = mem::size_of::(); + let _ = mem::size_of::(); // to avoid compilation warning about unused `mem` let ud = if protect { - protect_lua!(state, 0, 1, |state| { - ffi::lua_newuserdatadtor(state, size, destructor::) as *mut T - })? + protect_lua!(state, 0, 1, |state| { ffi::lua_newuserdata_t::(state) })? } else { - ffi::lua_newuserdatadtor(state, size, destructor::) as *mut T + ffi::lua_newuserdata_t::(state) }; ptr::write(ud, t); - Ok(()) } @@ -1009,16 +1002,10 @@ pub(crate) enum WrappedFailure { impl WrappedFailure { pub(crate) unsafe fn new_userdata(state: *mut ffi::lua_State) -> *mut Self { - let size = mem::size_of::(); #[cfg(feature = "luau")] - let ud = { - unsafe extern "C-unwind" fn destructor(p: *mut c_void) { - ptr::drop_in_place(p as *mut WrappedFailure); - } - ffi::lua_newuserdatadtor(state, size, destructor) as *mut Self - }; + let ud = ffi::lua_newuserdata_t::(state); #[cfg(not(feature = "luau"))] - let ud = ffi::lua_newuserdata(state, size) as *mut Self; + let ud = ffi::lua_newuserdata(state, mem::size_of::()) as *mut Self; ptr::write(ud, WrappedFailure::None); ud } diff --git a/tests/luau.rs b/tests/luau.rs index 93a34114..793f78be 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -1,6 +1,5 @@ #![cfg(feature = "luau")] -use std::env; use std::fmt::Debug; use std::fs; use std::panic::{catch_unwind, AssertUnwindSafe}; @@ -37,7 +36,10 @@ fn test_require() -> Result<()> { "#, )?; - env::set_var("LUAU_PATH", temp_dir.path().join("?.luau")); + lua.globals() + .get::<_, Table>("package")? + .set("path", temp_dir.path().join("?.luau").to_string_lossy())?; + lua.load( r#" local module = require("module") diff --git a/tests/module/Cargo.toml b/tests/module/Cargo.toml index c2e0da8d..f107ad73 100644 --- a/tests/module/Cargo.toml +++ b/tests/module/Cargo.toml @@ -18,6 +18,7 @@ lua53 = ["mlua/lua53"] lua52 = ["mlua/lua52"] lua51 = ["mlua/lua51"] luajit = ["mlua/luajit"] +luau = ["mlua/luau"] [dependencies] mlua = { path = "../..", features = ["module"] } diff --git a/tests/module/loader/Cargo.toml b/tests/module/loader/Cargo.toml index b51f002c..64b196ff 100644 --- a/tests/module/loader/Cargo.toml +++ b/tests/module/loader/Cargo.toml @@ -10,6 +10,7 @@ lua53 = ["mlua/lua53"] lua52 = ["mlua/lua52"] lua51 = ["mlua/lua51"] luajit = ["mlua/luajit"] +luau = ["mlua/luau"] vendored = ["mlua/vendored"] [dependencies] diff --git a/tests/module/loader/tests/load.rs b/tests/module/loader/tests/load.rs index d06ece4f..25f85ab0 100644 --- a/tests/module/loader/tests/load.rs +++ b/tests/module/loader/tests/load.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use mlua::{Lua, Result}; #[test] -fn test_module() -> Result<()> { +fn test_module_simple() -> Result<()> { let lua = make_lua()?; lua.load( r#"