diff --git a/Cargo.toml b/Cargo.toml index 5a065523..9ee1587c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,9 @@ module = ["dep:mlua_derive", "ffi/module"] async = ["dep:futures-util"] send = [] serialize = ["dep:serde", "dep:erased-serde", "dep:serde-value"] +uuid = ["dep:uuid", "dep:serde"] +time = ["dep:time"] +json = ["serialize", "serde_json"] macros = ["mlua_derive/macros"] unstable = [] @@ -51,8 +54,11 @@ num-traits = { version = "0.2.14" } rustc-hash = "2.0" futures-util = { version = "0.3", optional = true, default-features = false, features = ["std"] } serde = { version = "1.0", optional = true } +serde_json = { version = "1.0", optional = true} +uuid = { version = "1.10.0", optional = true, features = ["v7", "serde"]} erased-serde = { version = "0.4", optional = true } serde-value = { version = "0.7", optional = true } +time = {version = "0.3.36", optional = true, features = ["parsing"]} parking_lot = { version = "0.12", optional = true } ffi = { package = "mlua-sys", version = "0.6.1", path = "mlua-sys" } diff --git a/mlua_derive/src/from_lua_table.rs b/mlua_derive/src/from_lua_table.rs new file mode 100644 index 00000000..ef5969fe --- /dev/null +++ b/mlua_derive/src/from_lua_table.rs @@ -0,0 +1,45 @@ +use proc_macro::TokenStream; +use quote::{quote, format_ident}; +use syn::{parse_macro_input, DeriveInput, Data, Fields}; + +pub fn from_lua_table(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let ident = input.ident; + + let fields = if let Data::Struct(data_struct) = input.data { + match data_struct.fields { + Fields::Named(fields) => fields, + _ => panic!("FromLuaTable can only be derived for structs with named fields"), + } + } else { + panic!("FromLuaTable can only be derived for structs"); + }; + + let get_fields = fields.named.iter().map(|field| { + let name = &field.ident; + let name_str = name.as_ref().unwrap().to_string(); + quote! { + #name: table.get(#name_str)?, + } + }); + + let gen = quote! { + impl<'lua> ::mlua::FromLua<'lua> for #ident { + fn from_lua(lua_value: ::mlua::Value<'lua>, lua: &'lua ::mlua::Lua) -> ::mlua::Result { + if let ::mlua::Value::Table(table) = lua_value { + Ok(Self { + #(#get_fields)* + }) + } else { + Err(::mlua::Error::FromLuaConversionError { + from: lua_value.type_name(), + to: stringify!(#ident), + message: Some(String::from("expected a Lua table")), + }) + } + } + } + }; + + gen.into() +} diff --git a/mlua_derive/src/lib.rs b/mlua_derive/src/lib.rs index 74605cb9..c9ab44eb 100644 --- a/mlua_derive/src/lib.rs +++ b/mlua_derive/src/lib.rs @@ -153,9 +153,33 @@ pub fn from_lua(input: TokenStream) -> TokenStream { from_lua::from_lua(input) } +#[cfg(feature = "macros")] +#[proc_macro_derive(ToLua)] +pub fn to_lua(input: TokenStream) -> TokenStream { + to_lua::to_lua(input) +} + +#[cfg(feature = "macros")] +#[proc_macro_derive(FromLuaTable)] +pub fn from_lua_table(input: TokenStream) -> TokenStream { + from_lua_table::from_lua_table(input) +} + +#[cfg(feature = "macros")] +#[proc_macro_derive(ToLuaTable)] +pub fn to_lua_table(input: TokenStream) -> TokenStream { + to_lua_table::to_lua_table(input) +} + #[cfg(feature = "macros")] mod chunk; #[cfg(feature = "macros")] mod from_lua; #[cfg(feature = "macros")] +mod to_lua; +#[cfg(feature = "macros")] +mod from_lua_table; +#[cfg(feature = "macros")] +mod to_lua_table; +#[cfg(feature = "macros")] mod token; diff --git a/mlua_derive/src/to_lua.rs b/mlua_derive/src/to_lua.rs new file mode 100644 index 00000000..984185ee --- /dev/null +++ b/mlua_derive/src/to_lua.rs @@ -0,0 +1,77 @@ +use proc_macro::TokenStream; +use quote::{quote, format_ident}; +use syn::{parse_macro_input, DeriveInput, Data, Fields, Type}; +use syn::spanned::Spanned; +use syn::punctuated::Punctuated; +use proc_macro2::TokenStream as TokenStream2; + +pub fn to_lua(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let ident = input.ident; + + let fields = if let Data::Struct(data_struct) = input.data { + match data_struct.fields { + Fields::Named(fields) => fields, + _ => panic!("ToLua can only be derived for structs with named fields"), + } + } else { + panic!("ToLua can only be derived for structs"); + }; + + let add_field_methods = fields.named.iter().map(|field| { + let name = &field.ident; + let name_str = name.as_ref().unwrap().to_string(); + let ty = &field.ty; + + let get_method = if is_copy_type(ty) { + quote! { + fields.add_field_method_get(#name_str, |_, this| Ok(this.#name)); + } + } else { + quote! { + fields.add_field_method_get(#name_str, |_, this| Ok(this.#name.clone())); + } + }; + + let set_method = quote! { + fields.add_field_method_set(#name_str, |_, this, val| { + this.#name = val; + Ok(()) + }); + }; + + quote! { + #get_method + #set_method + } + }); + + let gen = quote! { + impl ::mlua::UserData for #ident { + fn add_fields<'lua, F: ::mlua::prelude::LuaUserDataFields<'lua, Self>>(fields: &mut F) { + #(#add_field_methods)* + } + } + }; + + gen.into() +} + +// I don't know how to determine whether or not something implements copy, so for now everything +// will be cloned that isn't one of these copyable primitives. +fn is_copy_type(ty: &Type) -> bool { + match ty { + Type::Path(type_path) => { + let segments = &type_path.path.segments; + let segment = segments.last().unwrap(); + match segment.ident.to_string().as_str() { + "u8" | "u16" | "u32" | "u64" | "u128" | + "i8" | "i16" | "i32" | "i64" | "i128" | + "f32" | "f64" | + "bool" | "char" | "usize" | "isize" => true, + _ => false, + } + } + _ => false, + } +} diff --git a/mlua_derive/src/to_lua_table.rs b/mlua_derive/src/to_lua_table.rs new file mode 100644 index 00000000..506f08d9 --- /dev/null +++ b/mlua_derive/src/to_lua_table.rs @@ -0,0 +1,37 @@ +use proc_macro::TokenStream; +use quote::{quote, format_ident}; +use syn::{parse_macro_input, DeriveInput, Data, Fields}; + +pub fn to_lua_table(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let ident = input.ident; + + let fields = if let Data::Struct(data_struct) = input.data { + match data_struct.fields { + Fields::Named(fields) => fields, + _ => panic!("ToLua can only be derived for structs with named fields"), + } + } else { + panic!("ToLua can only be derived for structs"); + }; + + let set_fields = fields.named.iter().map(|field| { + let name = &field.ident; + let name_str = name.as_ref().unwrap().to_string(); + quote! { + table.set(#name_str, self.#name)?; + } + }); + + let gen = quote! { + impl<'lua> ::mlua::IntoLua<'lua> for #ident { + fn into_lua(self, lua: &'lua ::mlua::Lua) -> ::mlua::Result<::mlua::Value<'lua>> { + let table = lua.create_table()?; + #(#set_fields)* + Ok(::mlua::Value::Table(table)) + } + } + }; + + gen.into() +} diff --git a/src/conversion.rs b/src/conversion.rs index 6013ebf9..76b76e00 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -520,6 +520,171 @@ impl<'lua> FromLua<'lua> for LightUserData { } } +#[cfg(feature = "time")] +impl<'lua> IntoLua<'lua> for time::OffsetDateTime { + fn into_lua(self, lua: &'lua Lua) -> Result> { + let datetime_str = self.format(&time::format_description::well_known::Rfc3339).map_err(|e| Error::RuntimeError(e.to_string()))?; + let lua_string = lua.create_string(&datetime_str)?; + Ok(Value::String(lua_string)) + } +} + +#[cfg(feature = "time")] +impl<'lua> FromLua<'lua> for time::OffsetDateTime { + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { + match value { + Value::String(lua_string) => { + let datetime_str = lua_string.to_str()?; + time::OffsetDateTime::parse(datetime_str, &time::format_description::well_known::Rfc3339).map_err(|e| Error::FromLuaConversionError { + from: "string", + to: "time::OffsetDateTime", + message: Some(e.to_string()), + }) + }, + _ => Err(Error::FromLuaConversionError { + from: value.type_name(), + to: "time::OffsetDateTime", + message: Some("Expected a string".to_string()), + }), + } + } +} + +#[cfg(feature = "json")] +impl<'lua> IntoLua<'lua> for serde_json::Value { + #[inline] + fn into_lua(self, lua: &'lua Lua) -> Result> { + match self { + serde_json::Value::Null => Ok(Value::Nil), + serde_json::Value::Bool(b) => Ok(Value::Boolean(b)), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + Ok(Value::Integer(i)) + } else if let Some(f) = n.as_f64() { + Ok(Value::Number(f)) + } else { + Err(Error::FromLuaConversionError { + from: "number", + to: "Value", + message: Some("Invalid number".to_string()), + }) + } + }, + serde_json::Value::String(s) => { + let lua_string = lua.create_string(&s)?; + Ok(Value::String(lua_string)) + }, + serde_json::Value::Array(arr) => { + let lua_table = lua.create_table()?; + for (i, value) in arr.into_iter().enumerate() { + lua_table.set(i + 1, value.into_lua(lua)?)?; + } + Ok(Value::Table(lua_table)) + }, + serde_json::Value::Object(obj) => { + let lua_table = lua.create_table()?; + for (key, value) in obj { + lua_table.set(key, value.into_lua(lua)?)?; + } + Ok(Value::Table(lua_table)) + }, + } + } +} + +#[cfg(feature = "json")] +impl<'lua> FromLua<'lua> for serde_json::Value { + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { + let ty = value.type_name(); + serde_json::to_value(value).map_err(|e| Error::FromLuaConversionError { + from: ty, + to: "serde_json::Value", + message: Some(format!("{}", e)), + }) + } +} + +#[cfg(feature = "uuid")] +impl<'lua> FromLua<'lua> for uuid::Uuid { + #[inline] + fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { + let ty = value.type_name(); + let string_result = lua.coerce_string(value)? + .ok_or_else(|| Error::FromLuaConversionError { + from: ty, + to: "string", + message: Some("expected string uuid".to_string()), + }); + match string_result { + Ok(string) => { + match uuid::Uuid::parse_str(string.to_str()?) { + Ok(val) => Ok(val), + Err(_) => Err(Error::FromLuaConversionError { + from: "string", + to: "uuid::Uuid", + message: Some("failed to parse UUID".to_string()), + }) + } + }, + Err(e) => Err(e) + } + } +} + +#[cfg(feature = "uuid")] +impl<'lua> IntoLua<'lua> for uuid::Uuid { + #[inline] + fn into_lua(self, lua: &'lua Lua) -> Result> { + let uuid_string = lua.create_string(self.to_string().as_str())?; + Ok(Value::String(uuid_string)) + } +} + +#[cfg(feature = "uuid")] +impl<'lua> IntoLua<'lua> for &uuid::Uuid { + #[inline] + fn into_lua(self, lua: &'lua Lua) -> Result> { + let uuid_string = lua.create_string(self.to_string().as_str())?; + Ok(Value::String(uuid_string)) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { + let uuid_string = lua.create_string(self.to_string().as_str())?; + lua.push_ref(&uuid_string.0); + Ok(()) + } +} + + +// impl<'lua> FromLua<'lua> for Value<'lua> { +// #[inline] +// fn from_lua(lua_value: Value<'lua>, _: &'lua Lua) -> Result { +// Ok(lua_value) +// } +// } + +// impl<'lua> IntoLua<'lua> for String<'lua> { +// #[inline] +// fn into_lua(self, _: &'lua Lua) -> Result> { +// Ok(Value::String(self)) +// } +// } + +// impl<'lua> IntoLua<'lua> for &String<'lua> { +// #[inline] +// fn into_lua(self, _: &'lua Lua) -> Result> { +// Ok(Value::String(self.clone())) +// } + +// #[inline] +// unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { +// lua.push_ref(&self.0); +// Ok(()) +// } +// } + + #[cfg(feature = "luau")] impl<'lua> IntoLua<'lua> for crate::types::Vector { #[inline] diff --git a/src/lib.rs b/src/lib.rs index 9345f5f9..e9b18c77 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -225,6 +225,21 @@ pub use mlua_derive::chunk; #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] pub use mlua_derive::FromLua; +/// Derive [`ToLua`] for a Rust type. +/// +/// Nested types require [`IntoLua`] as well +#[cfg(feature = "macros")] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +pub use mlua_derive::ToLua; + +#[cfg(feature = "macros")] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +pub use mlua_derive::ToLuaTable; + +#[cfg(feature = "macros")] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +pub use mlua_derive::FromLuaTable; + /// Registers Lua module entrypoint. /// /// You can register multiple entrypoints as required.