diff --git a/serde/src/de/value.rs b/serde/src/de/value.rs index 1ec947786..c5c1e9573 100644 --- a/serde/src/de/value.rs +++ b/serde/src/de/value.rs @@ -24,7 +24,9 @@ use crate::lib::*; use self::private::{First, Second}; -use crate::de::{self, size_hint, Deserializer, Expected, IntoDeserializer, SeqAccess, Visitor}; +use crate::de::{ + self, size_hint, Deserializer, Expected, IgnoredAny, IntoDeserializer, SeqAccess, Visitor, +}; use crate::ser; //////////////////////////////////////////////////////////////////////////////// @@ -978,7 +980,9 @@ where } } -struct ExpectedInSeq(usize); +/// Number of elements still expected in a sequence. Does not include already +/// read elements. +pub(crate) struct ExpectedInSeq(pub usize); impl Expected for ExpectedInSeq { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { @@ -1076,9 +1080,38 @@ where visitor.visit_seq(self.seq) } + fn deserialize_unit(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + // Covered by tests/test_enum_internally_tagged.rs + // newtype_unit + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + // Covered by tests/test_enum_internally_tagged.rs + // newtype_unit_struct + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct(self, _name: &str, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + forward_to_deserialize_any! { bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string - bytes byte_buf option unit unit_struct newtype_struct seq tuple + bytes byte_buf option seq tuple tuple_struct map struct enum identifier ignored_any } } @@ -1406,6 +1439,8 @@ where } } +/// Number of elements still expected in a map. Does not include already read +/// elements. struct ExpectedInMap(usize); impl Expected for ExpectedInMap { @@ -1479,6 +1514,42 @@ where visitor.visit_map(self.map) } + fn deserialize_unit(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + // Covered by tests/test_enum_internally_tagged.rs + // newtype_unit + tri!(IgnoredAny.visit_map(self.map)); + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + // Covered by tests/test_enum_internally_tagged.rs + // newtype_unit_struct + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + // Covered by tests/test_enum_internally_tagged.rs + // newtype_newtype + visitor.visit_newtype_struct(self) + } + fn deserialize_enum( self, _name: &str, @@ -1493,7 +1564,7 @@ where forward_to_deserialize_any! { bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string - bytes byte_buf option unit unit_struct newtype_struct seq tuple + bytes byte_buf option seq tuple tuple_struct map struct identifier ignored_any } } diff --git a/serde/src/private/de.rs b/serde/src/private/de.rs index 50ae6ed15..1e3b58a26 100644 --- a/serde/src/private/de.rs +++ b/serde/src/private/de.rs @@ -209,7 +209,9 @@ mod content { use crate::lib::*; use crate::actually_private; - use crate::de::value::{MapDeserializer, SeqDeserializer}; + use crate::de::value::{ + ExpectedInSeq, MapAccessDeserializer, MapDeserializer, SeqDeserializer, + }; use crate::de::{ self, size_hint, Deserialize, DeserializeSeed, Deserializer, EnumAccess, Expected, IgnoredAny, MapAccess, SeqAccess, Unexpected, Visitor, @@ -536,9 +538,7 @@ mod content { } /// This is the type of the map keys in an internally tagged enum. - /// - /// Not public API. - pub enum TagOrContent<'de> { + enum TagOrContent<'de> { Tag, Content(Content<'de>), } @@ -855,9 +855,9 @@ mod content { impl<'de, T> Visitor<'de> for TaggedContentVisitor where - T: Deserialize<'de>, + T: Deserialize<'de> + DeserializeSeed<'de>, { - type Value = (T, Content<'de>); + type Value = T::Value; fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.write_str(self.expecting) @@ -867,42 +867,63 @@ mod content { where S: SeqAccess<'de>, { - let tag = match tri!(seq.next_element()) { + let tag: T = match tri!(seq.next_element()) { Some(tag) => tag, None => { return Err(de::Error::missing_field(self.tag_name)); } }; - let rest = de::value::SeqAccessDeserializer::new(seq); - Ok((tag, tri!(Content::deserialize(rest)))) + tag.deserialize(de::value::SeqAccessDeserializer::new(seq)) } fn visit_map(self, mut map: M) -> Result where M: MapAccess<'de>, { - let mut tag = None; - let mut vec = Vec::<(Content, Content)>::with_capacity(size_hint::cautious::<( - Content, - Content, - )>(map.size_hint())); - while let Some(k) = tri!(map.next_key_seed(TagOrContentVisitor::new(self.tag_name))) { - match k { - TagOrContent::Tag => { - if tag.is_some() { - return Err(de::Error::duplicate_field(self.tag_name)); + // Read the first field. If it is a tag, immediately deserialize the typed data. + // Otherwise, we collect everything until we find the tag, and then deserialize + // using ContentDeserializer. + match tri!(map.next_key_seed(TagOrContentVisitor::new(self.tag_name))) { + Some(TagOrContent::Tag) => { + let tag: T = tri!(map.next_value()); + tag.deserialize(MapAccessDeserializer::new(map)) + } + Some(TagOrContent::Content(key)) => { + let mut tag = None::; + let mut vec = Vec::<(Content, Content)>::with_capacity(size_hint::cautious::<( + Content, + Content, + )>( + map.size_hint() + )); + + let v = tri!(map.next_value()); + vec.push((key, v)); + + while let Some(k) = + tri!(map.next_key_seed(TagOrContentVisitor::new(self.tag_name))) + { + match k { + TagOrContent::Tag => { + if tag.is_some() { + return Err(de::Error::duplicate_field(self.tag_name)); + } + tag = Some(tri!(map.next_value())); + } + TagOrContent::Content(k) => { + let v = tri!(map.next_value()); + vec.push((k, v)); + } } - tag = Some(tri!(map.next_value())); } - TagOrContent::Content(k) => { - let v = tri!(map.next_value()); - vec.push((k, v)); + match tag { + None => Err(de::Error::missing_field(self.tag_name)), + Some(tag) => { + tag.deserialize(ContentDeserializer::::new(Content::Map(vec))) + } } } - } - match tag { None => Err(de::Error::missing_field(self.tag_name)), - Some(tag) => Ok((tag, Content::Map(vec))), } } } @@ -2296,11 +2317,17 @@ mod content { ) } - fn visit_seq(self, _: S) -> Result<(), S::Error> + fn visit_seq(self, mut seq: S) -> Result<(), S::Error> where S: SeqAccess<'de>, { - Ok(()) + match tri!(seq.next_element()) { + Some(IgnoredAny) => Err(de::Error::invalid_length( + 1 + seq.size_hint().unwrap_or(0), + &ExpectedInSeq(0), + )), + None => Ok(()), + } } fn visit_map(self, mut access: M) -> Result<(), M::Error> diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 518f84320..7ddfdd9a0 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1067,7 +1067,7 @@ fn deserialize_struct( _serde::de::VariantAccess::struct_variant(__variant, FIELDS, #visitor_expr) }, StructForm::InternallyTagged(_, deserializer) => quote! { - _serde::Deserializer::deserialize_any(#deserializer, #visitor_expr) + _serde::Deserializer::deserialize_map(#deserializer, #visitor_expr) }, StructForm::Untagged(_, deserializer) => quote! { _serde::Deserializer::deserialize_any(#deserializer, #visitor_expr) @@ -1397,19 +1397,55 @@ fn deserialize_internally_tagged_enum( let expecting = format!("internally tagged enum {}", params.type_name()); let expecting = cattrs.expecting().unwrap_or(&expecting); + let this_type = ¶ms.this_type; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = + split_with_de_lifetime(params); + let delife = params.borrowed.de_lifetime(); + quote_block! { #variant_visitor #variants_stmt - let (__tag, __content) = _serde::Deserializer::deserialize_any( - __deserializer, - _serde::__private::de::TaggedContentVisitor::<__Field>::new(#tag, #expecting))?; - let __deserializer = _serde::__private::de::ContentDeserializer::<__D::Error>::new(__content); + struct __Seed #de_impl_generics #where_clause { + tag: __Field, + marker: _serde::__private::PhantomData<#this_type #ty_generics>, + lifetime: _serde::__private::PhantomData<&#delife ()>, + } - match __tag { - #(#variant_arms)* + impl #de_impl_generics _serde::de::Deserialize<#delife> for __Seed #de_ty_generics #where_clause { + fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result + where + __D: _serde::de::Deserializer<#delife>, + { + _serde::__private::Result::map( + __Field::deserialize(__deserializer), + |__tag| __Seed { + tag: __tag, + marker: _serde::__private::PhantomData, + lifetime: _serde::__private::PhantomData, + } + ) + } } + + impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Seed #de_ty_generics #where_clause { + type Value = #this_type #ty_generics; + + fn deserialize<__D>(self, __deserializer: __D) -> _serde::__private::Result + where + __D: _serde::de::Deserializer<#delife>, + { + match self.tag { + #(#variant_arms)* + } + } + } + + _serde::Deserializer::deserialize_map( + __deserializer, + _serde::__private::de::TaggedContentVisitor::<__Seed>::new(#tag, #expecting) + ) } } @@ -1862,7 +1898,7 @@ fn deserialize_internally_tagged_variant( quote!((#default)) }); quote_block! { - _serde::Deserializer::deserialize_any(#deserializer, _serde::__private::de::InternallyTaggedUnitVisitor::new(#type_name, #variant_name))?; + _serde::Deserializer::deserialize_map(#deserializer, _serde::__private::de::InternallyTaggedUnitVisitor::new(#type_name, #variant_name))?; _serde::__private::Ok(#this_value::#variant_ident #default) } } diff --git a/test_suite/tests/test_enum_internally_tagged.rs b/test_suite/tests/test_enum_internally_tagged.rs index b4d428c4d..c0f0ae105 100644 --- a/test_suite/tests/test_enum_internally_tagged.rs +++ b/test_suite/tests/test_enum_internally_tagged.rs @@ -320,8 +320,9 @@ fn newtype_map() { Token::Seq { len: Some(2) }, Token::Str("NewtypeMap"), // tag Token::Map { len: Some(0) }, - Token::MapEnd, - Token::SeqEnd, + // Tokens that could follow, but assert_de_tokens_error does not want them + // Token::MapEnd, + // Token::SeqEnd, ], "invalid type: sequence, expected a map", );