Skip to content

Commit

Permalink
Support enum variants that have aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
dtolnay committed Oct 28, 2024
1 parent d1a1ff0 commit 97576f7
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 56 deletions.
32 changes: 24 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions serde-reflection/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ exclude = [
]

[dependencies]
thiserror = "1.0.25"
serde = { version = "1.0.126", features = ["derive"] }
erased-discriminant = "1"
once_cell = "1.7.2"
serde = { version = "1.0.126", features = ["derive"] }
thiserror = "1.0.25"

[dev-dependencies]
bincode = "1.3.3"
Expand Down
175 changes: 136 additions & 39 deletions serde-reflection/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
use crate::{
error::{Error, Result},
format::{ContainerFormat, ContainerFormatEntry, Format, FormatHolder, Named, VariantFormat},
trace::{Samples, Tracer},
trace::{EnumProgress, Samples, Tracer, VariantId},
value::IntoSeqDeserializer,
};
use serde::de::{self, DeserializeSeed, IntoDeserializer, Visitor};
use std::collections::BTreeMap;
use erased_discriminant::Discriminant;
use serde::de::{
self,
value::{BorrowedStrDeserializer, U32Deserializer},
DeserializeSeed, IntoDeserializer, Visitor,
};
use std::collections::btree_map::{BTreeMap, Entry};

/// Deserialize a single value.
/// * The lifetime 'a is set by the deserialization call site and the
Expand Down Expand Up @@ -391,55 +396,145 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> {

// Assumption: The first variant(s) should be "base cases", i.e. not cause infinite recursion
// while constructing sample values.
#[allow(clippy::map_entry)] // false positive https://github.com/rust-lang/rust-clippy/issues/9470
fn deserialize_enum<V>(
self,
name: &'static str,
enum_name: &'static str,
variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::TypeName(name.into()))?;
self.format.unify(Format::TypeName(enum_name.into()))?;
// Pre-update the registry.
self.tracer
.registry
.entry(name.to_string())
.entry(enum_name.to_string())
.unify(ContainerFormat::Enum(BTreeMap::new()))?;
let known_variants = match self.tracer.registry.get_mut(name) {
let known_variants = match self.tracer.registry.get_mut(enum_name) {
Some(ContainerFormat::Enum(x)) => x,
_ => unreachable!(),
};
// If we have found all the variants OR if the enum is marked as
// incomplete already, pick the first index.
let index = if known_variants.len() == variants.len()
|| self.tracer.incomplete_enums.contains(name)
{
0
} else {
let mut index = known_variants.len() as u32;
// Scan the range 0..=known_variants.len() downwards to find the next
// variant index to explore.
while known_variants.contains_key(&index) {
index -= 1;

// If the enum is marked as incomplete, just visit the first index
// because we presume it avoids recursion.
if self.tracer.incomplete_enums.contains_key(enum_name) {
return visitor.visit_enum(EnumDeserializer::new(
self.tracer,
self.samples,
VariantId::Index(0),
&mut VariantFormat::unknown(),
));
}

// First visit each of the variants by name according to `variants`.
// Later revisit them by u32 index until an index matching each of the
// named variants has been determined.
let provisional_min = u32::MAX - (variants.len() - 1) as u32;
for (i, &variant_name) in variants.iter().enumerate() {
if !self
.tracer
.discriminants
.contains_key(&(enum_name, VariantId::Name(variant_name)))
{
self.tracer
.incomplete_enums
.insert(enum_name.into(), EnumProgress::NamedVariantsRemaining);
// Insert into known_variants with a provisional index.
let provisional_index = provisional_min + i as u32;
let variant = known_variants
.entry(provisional_index)
.or_insert_with(|| Named {
name: variant_name.to_owned(),
value: VariantFormat::unknown(),
});
let mut value = variant.value.clone();
// Compute the discriminant and format for this variant.
let enum_value = visitor.visit_enum(EnumDeserializer::new(
self.tracer,
self.samples,
VariantId::Name(variant_name),
&mut value,
))?;
let discriminant = Discriminant::of(&enum_value);
self.tracer
.discriminants
.insert((enum_name, VariantId::Name(variant_name)), discriminant);
return Ok(enum_value);
}
}

// We know the discriminant for every variant name. Now visit them again
// by index to find the u32 id that goes with each name.
//
// If there are no provisional entries waiting for an index, just go
// with index 0.
let mut index = 0;
if known_variants.range(provisional_min..).next().is_some() {
self.tracer
.incomplete_enums
.insert(enum_name.into(), EnumProgress::IndexedVariantsRemaining);
while known_variants.contains_key(&index)
&& self
.tracer
.discriminants
.contains_key(&(enum_name, VariantId::Index(index)))
{
index += 1;
}
index
}

// Compute the discriminant and format for this variant.
let mut value = VariantFormat::unknown();
let enum_value = visitor.visit_enum(EnumDeserializer::new(
self.tracer,
self.samples,
VariantId::Index(index),
&mut value,
))?;
let discriminant = Discriminant::of(&enum_value);
self.tracer
.discriminants
.insert((enum_name, VariantId::Index(index)), discriminant.clone());
self.tracer.incomplete_enums.remove(enum_name);

// Rewrite provisional entries for which we now know a u32 index.
let known_variants = match self.tracer.registry.get_mut(enum_name) {
Some(ContainerFormat::Enum(x)) => x,
_ => unreachable!(),
};
let variant = known_variants.entry(index).or_insert_with(|| Named {
name: (*variants
.get(index as usize)
.expect("variant indexes must be a non-empty range 0..variants.len()"))
.to_string(),
value: VariantFormat::unknown(),
});
let mut value = variant.value.clone();
// Mark the enum as incomplete if this was not the last variant to explore.
if known_variants.len() != variants.len() {
self.tracer.incomplete_enums.insert(name.into());
for provisional_index in provisional_min..=u32::MAX {
if let Entry::Occupied(provisional_entry) = known_variants.entry(provisional_index) {
if self.tracer.discriminants
[&(enum_name, VariantId::Name(&provisional_entry.get().name))]
== discriminant
{
let provisional_entry = provisional_entry.remove();
match known_variants.entry(index) {
Entry::Vacant(vacant) => {
vacant.insert(provisional_entry);
}
Entry::Occupied(mut existing_entry) => {
// Discard the provisional entry's name and just
// keep the existing one.
existing_entry
.get_mut()
.value
.unify(provisional_entry.value)?;
}
}
} else {
self.tracer
.incomplete_enums
.insert(enum_name.into(), EnumProgress::IndexedVariantsRemaining);
}
}
}
if let Some(existing_entry) = known_variants.get_mut(&index) {
existing_entry.value.unify(value)?;
}
// Compute the format for this variant.
let inner = EnumDeserializer::new(self.tracer, self.samples, index, &mut value);
visitor.visit_enum(inner)
Ok(enum_value)
}

fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
Expand Down Expand Up @@ -539,21 +634,21 @@ where
struct EnumDeserializer<'de, 'a> {
tracer: &'a mut Tracer,
samples: &'de Samples,
index: u32,
variant_id: VariantId<'static>,
format: &'a mut VariantFormat,
}

impl<'de, 'a> EnumDeserializer<'de, 'a> {
fn new(
tracer: &'a mut Tracer,
samples: &'de Samples,
index: u32,
variant_id: VariantId<'static>,
format: &'a mut VariantFormat,
) -> Self {
Self {
tracer,
samples,
index,
variant_id,
format,
}
}
Expand All @@ -567,8 +662,10 @@ impl<'de, 'a> de::EnumAccess<'de> for EnumDeserializer<'de, 'a> {
where
V: DeserializeSeed<'de>,
{
let index = self.index;
let value = seed.deserialize(index.into_deserializer())?;
let value = match self.variant_id {
VariantId::Index(index) => seed.deserialize(U32Deserializer::new(index)),
VariantId::Name(name) => seed.deserialize(BorrowedStrDeserializer::new(name)),
}?;
Ok((value, self))
}
}
Expand Down
Loading

0 comments on commit 97576f7

Please sign in to comment.