From 20d6aeba88b1afb5219727a61da41e210ec60c3b Mon Sep 17 00:00:00 2001 From: kgtkr Date: Fri, 5 Aug 2022 19:56:55 +0900 Subject: [PATCH 1/2] Support enum --- src/de.rs | 189 ++++++++++++++++++++++++++++++++++++++++++++++++++- src/value.rs | 4 ++ 2 files changed, 192 insertions(+), 1 deletion(-) diff --git a/src/de.rs b/src/de.rs index 082f3a1..d876fd1 100644 --- a/src/de.rs +++ b/src/de.rs @@ -43,6 +43,7 @@ impl<'de> de::Deserializer<'de> for Deserializer { where V: Visitor<'de>, { + // Maybe it should be an unconditional error. Because the error message is confusing. self.deserialize_map(vis) } @@ -106,7 +107,7 @@ impl<'de> de::Deserializer<'de> for Deserializer { forward_to_deserialize_any! { unit unit_struct - tuple_struct enum ignored_any + tuple_struct ignored_any } fn deserialize_u32(self, vis: V) -> Result @@ -286,6 +287,25 @@ impl<'de> de::Deserializer<'de> for Deserializer { self.deserialize_string(vis) } + + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + vis: V, + ) -> Result + where + V: Visitor<'de>, + { + debug!( + "deserialize enum: name: {} variants: {:?} from {:#?}", + name, variants, self.0 + ); + + let keys = variants.iter().map(|v| v.to_string()).collect(); + debug!("flatten keys: {:?}", keys); + vis.visit_enum(EnumAccessor::new(keys, self.0)) + } } struct SeqAccessor { @@ -374,6 +394,87 @@ impl<'de> de::MapAccess<'de> for MapAccessor { } } +struct EnumAccessor { + keys: std::vec::IntoIter, + node: Node, +} + +impl EnumAccessor { + fn new(keys: Vec, node: Node) -> Self { + debug!("access keys {:?} from enum", keys); + + Self { + keys: keys.into_iter(), + node, + } + } +} + +impl<'de> de::EnumAccess<'de> for EnumAccessor { + type Error = Error; + type Variant = VariantAccessor; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: DeserializeSeed<'de>, + { + let key = self + .keys + .into_iter() + .find(|key| self.node.value() == key) + .ok_or_else(|| de::Error::custom("no variant found"))?; + + let variant = VariantAccessor::new(self.node); + Ok((seed.deserialize(key.into_deserializer())?, variant)) + } +} + +struct VariantAccessor { + node: Node, +} + +impl VariantAccessor { + fn new(node: Node) -> Self { + Self { node } + } +} + +impl<'de> de::VariantAccess<'de> for VariantAccessor { + type Error = Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + if self.node.has_children() { + return Err(de::Error::custom("variant is not unit")); + } + Ok(()) + } + fn newtype_variant_seed(self, seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + seed.deserialize(Deserializer(self.node)) + } + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(de::Error::custom("tuple variant is not supported")) + } + fn struct_variant( + self, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + debug!("deserialize struct variant: fields: {:?}", fields); + let keys = fields.iter().map(|v| v.to_string()).collect(); + debug!("flatten keys: {:?}", keys); + visitor.visit_map(MapAccessor::new(keys, self.node)) + } +} + #[cfg(test)] mod tests { use serde::Deserialize; @@ -606,4 +707,90 @@ mod tests { assert_eq!(t["metasrv_log_level"], "DEBUG".to_string()) }) } + + #[derive(Deserialize, PartialEq, Debug)] + struct EnumNewtype { + bar: String, + } + + #[derive(Deserialize, PartialEq, Debug)] + struct ExternallyEnumStruct { + foo: ExternallyEnum, + } + + #[derive(Deserialize, PartialEq, Debug)] + enum ExternallyEnum { + X, + Y(EnumNewtype), + Z { a: i32 }, + } + + #[test] + fn test_from_env_externally_enum() { + let _ = env_logger::try_init(); + + temp_env::with_vars(vec![("FOO", Some("X"))], || { + let t: ExternallyEnumStruct = from_env().expect("must success"); + assert_eq!(t.foo, ExternallyEnum::X) + }); + + temp_env::with_vars(vec![("FOO", Some("Y")), ("FOO_BAR", Some("xxx"))], || { + let t: ExternallyEnumStruct = from_env().expect("must success"); + assert_eq!( + t.foo, + ExternallyEnum::Y(EnumNewtype { + bar: "xxx".to_string() + }) + ) + }); + + temp_env::with_vars(vec![("FOO", Some("Z")), ("FOO_A", Some("1"))], || { + let t: ExternallyEnumStruct = from_env().expect("must success"); + assert_eq!(t.foo, ExternallyEnum::Z { a: 1 }) + }); + } + + #[derive(Deserialize, PartialEq, Debug)] + struct InternallyEnumStruct { + foo: InternallyEnum, + } + + #[derive(Deserialize, PartialEq, Debug)] + #[serde(tag = "type")] + enum InternallyEnum { + X, + Y(EnumNewtype), + Z { a: i32 }, + } + + // Currently Internally / Adjacently / Untagged enum is not support by the following issues + // https://github.com/serde-rs/serde/issues/2187 + #[test] + #[ignore] + fn test_from_env_internally_enum() { + let _ = env_logger::try_init(); + + temp_env::with_vars(vec![("FOO_TYPE", Some("X"))], || { + let t: InternallyEnumStruct = from_env().expect("must success"); + assert_eq!(t.foo, InternallyEnum::X) + }); + + temp_env::with_vars( + vec![("FOO_TYPE", Some("Y")), ("FOO_BAR", Some("xxx"))], + || { + let t: InternallyEnumStruct = from_env().expect("must success"); + assert_eq!( + t.foo, + InternallyEnum::Y(EnumNewtype { + bar: "xxx".to_string() + }) + ) + }, + ); + + temp_env::with_vars(vec![("FOO_TYPE", Some("Z")), ("FOO_A", Some("1"))], || { + let t: InternallyEnumStruct = from_env().expect("must success"); + assert_eq!(t.foo, InternallyEnum::Z { a: 1 }) + }); + } } diff --git a/src/value.rs b/src/value.rs index 92f8a2a..7fdf141 100644 --- a/src/value.rs +++ b/src/value.rs @@ -46,6 +46,10 @@ impl Node { self.0 } + pub fn has_children(&self) -> bool { + !self.1.is_empty() + } + pub fn flatten(&self, prefix: &str) -> Vec { let mut m = Vec::new(); From d93270ba3e9aa8701e3ebda01a2ed5bed3294db4 Mon Sep 17 00:00:00 2001 From: kgtkr Date: Fri, 5 Aug 2022 21:23:43 +0900 Subject: [PATCH 2/2] return error, when call deserialize_any --- src/de.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/de.rs b/src/de.rs index d876fd1..b99ccb3 100644 --- a/src/de.rs +++ b/src/de.rs @@ -39,12 +39,11 @@ struct Deserializer(Node); impl<'de> de::Deserializer<'de> for Deserializer { type Error = Error; - fn deserialize_any(self, vis: V) -> Result + fn deserialize_any(self, _vis: V) -> Result where V: Visitor<'de>, { - // Maybe it should be an unconditional error. Because the error message is confusing. - self.deserialize_map(vis) + Err(de::Error::custom("deserialize_any is not supported")) } fn deserialize_bool(self, vis: V) -> Result