From 05da6e667a4741175d9df90f7a4abe66066dd4ab Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Tue, 17 Dec 2024 15:28:17 -0800 Subject: [PATCH] Support enum variants that have aliases (#53) * Support enum variants that have aliases * Return error instead of panic on enum with 0 variants Before PR 53: thread 'main' panicked at /git/tmp/serde-reflection/serde-reflection/src/de.rs:431:18: variant indexes must be a non-empty range 0..variants.len() First draft of PR 53, debug mode (overflow checks): thread 'main' panicked at /git/tmp/serde-reflection/serde-reflection/src/de.rs:435:42: attempt to subtract with overflow First draft of PR 53, release mode: Failed to deserialize value: "invalid value: integer `0`, expected variant index 0 <= i < 0" This commit: Not supported: deserialize_enum with 0 variants --- Cargo.lock | 33 ++++-- serde-reflection/Cargo.toml | 6 +- serde-reflection/src/de.rs | 181 +++++++++++++++++++++++++------- serde-reflection/src/trace.rs | 38 +++++-- serde-reflection/tests/serde.rs | 5 +- 5 files changed, 207 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8e947fae9..3ae9cbe2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,6 +101,15 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +[[package]] +name = "erased-discriminant" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "373f74784b51403f16c5fa6dc667488389e629811329c1c6719c25874da2ba4f" +dependencies = [ + "typeid", +] + [[package]] name = "errno" version = "0.3.9" @@ -433,9 +442,9 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "serde" -version = "1.0.203" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" dependencies = [ "serde_derive", ] @@ -489,12 +498,14 @@ name = "serde-reflection" version = "0.4.0" dependencies = [ "bincode", + "erased-discriminant", "once_cell", "serde", "serde_bytes", "serde_json", "serde_yaml", "thiserror", + "typeid", ] [[package]] @@ -508,13 +519,13 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.85", ] [[package]] @@ -595,9 +606,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.66" +version = "2.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" dependencies = [ "proc-macro2", "quote", @@ -652,9 +663,15 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.85", ] +[[package]] +name = "typeid" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e" + [[package]] name = "unicode-ident" version = "1.0.12" diff --git a/serde-reflection/Cargo.toml b/serde-reflection/Cargo.toml index 24b90452a..2f7979b5d 100644 --- a/serde-reflection/Cargo.toml +++ b/serde-reflection/Cargo.toml @@ -17,9 +17,11 @@ exclude = [ ] [dependencies] -thiserror = "1.0.25" -serde = { version = "1.0.126", features = ["derive"] } +erased-discriminant = "1" once_cell = "1.7.2" +serde = { version = "1.0.126", features = ["derive"] } +thiserror = "1.0.25" +typeid = "1" [dev-dependencies] bincode = "1.3.3" diff --git a/serde-reflection/src/de.rs b/serde-reflection/src/de.rs index 9ee16213a..82ec0290f 100644 --- a/serde-reflection/src/de.rs +++ b/serde-reflection/src/de.rs @@ -4,11 +4,16 @@ use crate::{ error::{Error, Result}, format::{ContainerFormat, ContainerFormatEntry, Format, FormatHolder, Named, VariantFormat}, - trace::{Samples, Tracer}, + trace::{EnumProgress, Samples, Tracer, VariantId}, value::IntoSeqDeserializer, }; -use serde::de::{self, DeserializeSeed, IntoDeserializer, Visitor}; -use std::collections::BTreeMap; +use erased_discriminant::Discriminant; +use serde::de::{ + self, + value::{BorrowedStrDeserializer, U32Deserializer}, + DeserializeSeed, IntoDeserializer, Visitor, +}; +use std::collections::btree_map::{BTreeMap, Entry}; /// Deserialize a single value. /// * The lifetime 'a is set by the deserialization call site and the @@ -391,55 +396,151 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { // Assumption: The first variant(s) should be "base cases", i.e. not cause infinite recursion // while constructing sample values. + #[allow(clippy::map_entry)] // false positive https://github.com/rust-lang/rust-clippy/issues/9470 fn deserialize_enum( self, - name: &'static str, + enum_name: &'static str, variants: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { - self.format.unify(Format::TypeName(name.into()))?; + if variants.is_empty() { + return Err(Error::NotSupported("deserialize_enum with 0 variants")); + } + + let enum_type_id = typeid::of::(); + self.format.unify(Format::TypeName(enum_name.into()))?; // Pre-update the registry. self.tracer .registry - .entry(name.to_string()) + .entry(enum_name.to_string()) .unify(ContainerFormat::Enum(BTreeMap::new()))?; - let known_variants = match self.tracer.registry.get_mut(name) { + let known_variants = match self.tracer.registry.get_mut(enum_name) { Some(ContainerFormat::Enum(x)) => x, _ => unreachable!(), }; - // If we have found all the variants OR if the enum is marked as - // incomplete already, pick the first index. - let index = if known_variants.len() == variants.len() - || self.tracer.incomplete_enums.contains(name) - { - 0 - } else { - let mut index = known_variants.len() as u32; - // Scan the range 0..=known_variants.len() downwards to find the next - // variant index to explore. - while known_variants.contains_key(&index) { - index -= 1; + + // If the enum is marked as incomplete, just visit the first index + // because we presume it avoids recursion. + if self.tracer.incomplete_enums.contains_key(enum_name) { + return visitor.visit_enum(EnumDeserializer::new( + self.tracer, + self.samples, + VariantId::Index(0), + &mut VariantFormat::unknown(), + )); + } + + // First visit each of the variants by name according to `variants`. + // Later revisit them by u32 index until an index matching each of the + // named variants has been determined. + let provisional_min = u32::MAX - (variants.len() - 1) as u32; + for (i, &variant_name) in variants.iter().enumerate() { + if !self + .tracer + .discriminants + .contains_key(&(enum_type_id, VariantId::Name(variant_name))) + { + // Insert into known_variants with a provisional index. + let provisional_index = provisional_min + i as u32; + let variant = known_variants + .entry(provisional_index) + .or_insert_with(|| Named { + name: variant_name.to_owned(), + value: VariantFormat::unknown(), + }); + self.tracer + .incomplete_enums + .insert(enum_name.into(), EnumProgress::NamedVariantsRemaining); + // Compute the discriminant and format for this variant. + let mut value = variant.value.clone(); + let enum_value = visitor.visit_enum(EnumDeserializer::new( + self.tracer, + self.samples, + VariantId::Name(variant_name), + &mut value, + ))?; + let discriminant = Discriminant::of(&enum_value); + self.tracer + .discriminants + .insert((enum_type_id, VariantId::Name(variant_name)), discriminant); + return Ok(enum_value); } - index + } + + // We know the discriminant for every variant name. Now visit them again + // by index to find the u32 id that goes with each name. + // + // If there are no provisional entries waiting for an index, just go + // with index 0. + let mut index = 0; + if known_variants.range(provisional_min..).next().is_some() { + self.tracer + .incomplete_enums + .insert(enum_name.into(), EnumProgress::IndexedVariantsRemaining); + while known_variants.contains_key(&index) + && self + .tracer + .discriminants + .contains_key(&(enum_type_id, VariantId::Index(index))) + { + index += 1; + } + } + + // Compute the discriminant and format for this variant. + let mut value = VariantFormat::unknown(); + let enum_value = visitor.visit_enum(EnumDeserializer::new( + self.tracer, + self.samples, + VariantId::Index(index), + &mut value, + ))?; + let discriminant = Discriminant::of(&enum_value); + self.tracer.discriminants.insert( + (enum_type_id, VariantId::Index(index)), + discriminant.clone(), + ); + self.tracer.incomplete_enums.remove(enum_name); + + // Rewrite provisional entries for which we now know a u32 index. + let known_variants = match self.tracer.registry.get_mut(enum_name) { + Some(ContainerFormat::Enum(x)) => x, + _ => unreachable!(), }; - let variant = known_variants.entry(index).or_insert_with(|| Named { - name: (*variants - .get(index as usize) - .expect("variant indexes must be a non-empty range 0..variants.len()")) - .to_string(), - value: VariantFormat::unknown(), - }); - let mut value = variant.value.clone(); - // Mark the enum as incomplete if this was not the last variant to explore. - if known_variants.len() != variants.len() { - self.tracer.incomplete_enums.insert(name.into()); + for provisional_index in provisional_min..=u32::MAX { + if let Entry::Occupied(provisional_entry) = known_variants.entry(provisional_index) { + if self.tracer.discriminants + [&(enum_type_id, VariantId::Name(&provisional_entry.get().name))] + == discriminant + { + let provisional_entry = provisional_entry.remove(); + match known_variants.entry(index) { + Entry::Vacant(vacant) => { + vacant.insert(provisional_entry); + } + Entry::Occupied(mut existing_entry) => { + // Discard the provisional entry's name and just + // keep the existing one. + existing_entry + .get_mut() + .value + .unify(provisional_entry.value)?; + } + } + } else { + self.tracer + .incomplete_enums + .insert(enum_name.into(), EnumProgress::IndexedVariantsRemaining); + } + } + } + if let Some(existing_entry) = known_variants.get_mut(&index) { + existing_entry.value.unify(value)?; } - // Compute the format for this variant. - let inner = EnumDeserializer::new(self.tracer, self.samples, index, &mut value); - visitor.visit_enum(inner) + Ok(enum_value) } fn deserialize_identifier(self, _visitor: V) -> Result @@ -539,7 +640,7 @@ where struct EnumDeserializer<'de, 'a> { tracer: &'a mut Tracer, samples: &'de Samples, - index: u32, + variant_id: VariantId<'static>, format: &'a mut VariantFormat, } @@ -547,13 +648,13 @@ impl<'de, 'a> EnumDeserializer<'de, 'a> { fn new( tracer: &'a mut Tracer, samples: &'de Samples, - index: u32, + variant_id: VariantId<'static>, format: &'a mut VariantFormat, ) -> Self { Self { tracer, samples, - index, + variant_id, format, } } @@ -567,8 +668,10 @@ impl<'de, 'a> de::EnumAccess<'de> for EnumDeserializer<'de, 'a> { where V: DeserializeSeed<'de>, { - let index = self.index; - let value = seed.deserialize(index.into_deserializer())?; + let value = match self.variant_id { + VariantId::Index(index) => seed.deserialize(U32Deserializer::new(index)), + VariantId::Name(name) => seed.deserialize(BorrowedStrDeserializer::new(name)), + }?; Ok((value, self)) } } diff --git a/serde-reflection/src/trace.rs b/serde-reflection/src/trace.rs index fe3347616..c99820782 100644 --- a/serde-reflection/src/trace.rs +++ b/serde-reflection/src/trace.rs @@ -8,9 +8,11 @@ use crate::{ ser::Serializer, value::Value, }; +use erased_discriminant::Discriminant; use once_cell::sync::Lazy; use serde::{de::DeserializeSeed, Deserialize, Serialize}; -use std::collections::{BTreeMap, BTreeSet}; +use std::any::TypeId; +use std::collections::BTreeMap; /// A map of container formats. pub type Registry = BTreeMap; @@ -28,7 +30,24 @@ pub struct Tracer { /// Enums that have detected to be yet incomplete (i.e. missing variants) /// while tracing deserialization. - pub(crate) incomplete_enums: BTreeSet, + pub(crate) incomplete_enums: BTreeMap, + + /// Discriminant associated with each variant of each enum. + pub(crate) discriminants: BTreeMap<(TypeId, VariantId<'static>), Discriminant>, +} + +#[derive(Copy, Clone, Debug)] +pub(crate) enum EnumProgress { + /// There are variant names that have not yet been traced. + NamedVariantsRemaining, + /// There are variant numbers that have not yet been traced. + IndexedVariantsRemaining, +} + +#[derive(Eq, PartialEq, Ord, PartialOrd, Debug)] +pub(crate) enum VariantId<'a> { + Index(u32), + Name(&'a str), } /// User inputs, aka "samples", recorded during serialization. @@ -169,7 +188,8 @@ impl Tracer { Self { config, registry: BTreeMap::new(), - incomplete_enums: BTreeSet::new(), + incomplete_enums: BTreeMap::new(), + discriminants: BTreeMap::new(), } } @@ -235,9 +255,12 @@ impl Tracer { let (format, value) = self.trace_type_once::(samples)?; values.push(value); if let Format::TypeName(name) = &format { - if self.incomplete_enums.contains(name) { + if let Some(&progress) = self.incomplete_enums.get(name) { // Restart the analysis to find more variants of T. self.incomplete_enums.remove(name); + if let EnumProgress::NamedVariantsRemaining = progress { + values.pop().unwrap(); + } continue; } } @@ -271,9 +294,12 @@ impl Tracer { let (format, value) = self.trace_type_once_with_seed(samples, seed.clone())?; values.push(value); if let Format::TypeName(name) = &format { - if self.incomplete_enums.contains(name) { + if let Some(&progress) = self.incomplete_enums.get(name) { // Restart the analysis to find more variants of T. self.incomplete_enums.remove(name); + if let EnumProgress::NamedVariantsRemaining = progress { + values.pop().unwrap(); + } continue; } } @@ -298,7 +324,7 @@ impl Tracer { Ok(registry) } else { Err(Error::MissingVariants( - self.incomplete_enums.into_iter().collect(), + self.incomplete_enums.into_keys().collect(), )) } } diff --git a/serde-reflection/tests/serde.rs b/serde-reflection/tests/serde.rs index c31f1b405..7a617d9dd 100644 --- a/serde-reflection/tests/serde.rs +++ b/serde-reflection/tests/serde.rs @@ -13,7 +13,10 @@ enum E { Unit, Newtype(u16), Tuple(u16, Option), - Struct { a: u32 }, + Struct { + a: u32, + }, + #[serde(alias = "NewTupleArray2")] NewTupleArray((u16, u16, u16)), }