diff --git a/serde-reflection/README.md b/serde-reflection/README.md index 21e737285..ac15c588b 100644 --- a/serde-reflection/README.md +++ b/serde-reflection/README.md @@ -105,6 +105,13 @@ use the crate [`serde-name`](https://crates.io/crates/serde-name) and its adapte terminate. (Work around: re-order the variants. For instance `enum List { Some(Box), None}` must be rewritten `enum List { None, Some(Box)}`.) +* Certain standard types such as `std::num::NonZeroU8` may not be tracked as a +container and appear simply as their underlying primitive type (e.g. `u8`) in the +formats. This loss of information makes it difficult to use `trace_value` to work +around deserialization invariants (see example below). As a work around, you may +override the default for the primitive type using `TracerConfig` (e.g. `let config = +TracerConfig::default().default_u8_value(1);`). + ### Security CAVEAT At this time, `HashSet` and `BTreeSet` are treated as sequences (i.e. vectors) diff --git a/serde-reflection/src/de.rs b/serde-reflection/src/de.rs index d85a19826..9ee16213a 100644 --- a/serde-reflection/src/de.rs +++ b/serde-reflection/src/de.rs @@ -50,7 +50,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::Bool)?; - visitor.visit_bool(false) + visitor.visit_bool(self.tracer.config.default_bool_value) } fn deserialize_i8(self, visitor: V) -> Result @@ -58,7 +58,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::I8)?; - visitor.visit_i8(0) + visitor.visit_i8(self.tracer.config.default_i8_value) } fn deserialize_i16(self, visitor: V) -> Result @@ -66,7 +66,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::I16)?; - visitor.visit_i16(0) + visitor.visit_i16(self.tracer.config.default_i16_value) } fn deserialize_i32(self, visitor: V) -> Result @@ -74,7 +74,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::I32)?; - visitor.visit_i32(0) + visitor.visit_i32(self.tracer.config.default_i32_value) } fn deserialize_i64(self, visitor: V) -> Result @@ -82,7 +82,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::I64)?; - visitor.visit_i64(0) + visitor.visit_i64(self.tracer.config.default_i64_value) } fn deserialize_i128(self, visitor: V) -> Result @@ -90,7 +90,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::I128)?; - visitor.visit_i128(0) + visitor.visit_i128(self.tracer.config.default_i128_value) } fn deserialize_u8(self, visitor: V) -> Result @@ -98,7 +98,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::U8)?; - visitor.visit_u8(0) + visitor.visit_u8(self.tracer.config.default_u8_value) } fn deserialize_u16(self, visitor: V) -> Result @@ -106,7 +106,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::U16)?; - visitor.visit_u16(0) + visitor.visit_u16(self.tracer.config.default_u16_value) } fn deserialize_u32(self, visitor: V) -> Result @@ -114,7 +114,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::U32)?; - visitor.visit_u32(0) + visitor.visit_u32(self.tracer.config.default_u32_value) } fn deserialize_u64(self, visitor: V) -> Result @@ -122,7 +122,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::U64)?; - visitor.visit_u64(0) + visitor.visit_u64(self.tracer.config.default_u64_value) } fn deserialize_u128(self, visitor: V) -> Result @@ -130,7 +130,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::U128)?; - visitor.visit_u128(0) + visitor.visit_u128(self.tracer.config.default_u128_value) } fn deserialize_f32(self, visitor: V) -> Result @@ -138,7 +138,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::F32)?; - visitor.visit_f32(0.0) + visitor.visit_f32(self.tracer.config.default_f32_value) } fn deserialize_f64(self, visitor: V) -> Result @@ -146,7 +146,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::F64)?; - visitor.visit_f64(0.0) + visitor.visit_f64(self.tracer.config.default_f64_value) } fn deserialize_char(self, visitor: V) -> Result @@ -154,7 +154,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::Char)?; - visitor.visit_char('A') + visitor.visit_char(self.tracer.config.default_char_value) } fn deserialize_str(self, visitor: V) -> Result @@ -162,7 +162,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::Str)?; - visitor.visit_borrowed_str("") + visitor.visit_borrowed_str(self.tracer.config.default_borrowed_str_value) } fn deserialize_string(self, visitor: V) -> Result @@ -170,7 +170,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::Str)?; - visitor.visit_string(String::new()) + visitor.visit_string(self.tracer.config.default_string_value.clone()) } fn deserialize_bytes(self, visitor: V) -> Result @@ -178,7 +178,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::Bytes)?; - visitor.visit_borrowed_bytes(b"") + visitor.visit_borrowed_bytes(self.tracer.config.default_borrowed_bytes_value) } fn deserialize_byte_buf(self, visitor: V) -> Result @@ -186,7 +186,7 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { V: Visitor<'de>, { self.format.unify(Format::Bytes)?; - visitor.visit_byte_buf(Vec::new()) + visitor.visit_byte_buf(self.tracer.config.default_byte_buf_value.clone()) } fn deserialize_option(self, visitor: V) -> Result diff --git a/serde-reflection/src/lib.rs b/serde-reflection/src/lib.rs index 861796ea1..9fc0895f0 100644 --- a/serde-reflection/src/lib.rs +++ b/serde-reflection/src/lib.rs @@ -108,6 +108,13 @@ //! terminate. (Work around: re-order the variants. For instance `enum List { //! Some(Box), None}` must be rewritten `enum List { None, Some(Box)}`.) //! +//! * Certain standard types such as `std::num::NonZeroU8` may not be tracked as a +//! container and appear simply as their underlying primitive type (e.g. `u8`) in the +//! formats. This loss of information makes it difficult to use `trace_value` to work +//! around deserialization invariants (see example below). As a work around, you may +//! override the default for the primitive type using `TracerConfig` (e.g. `let config = +//! TracerConfig::default().default_u8_value(1);`). +//! //! ## Security CAVEAT //! //! At this time, `HashSet` and `BTreeSet` are treated as sequences (i.e. vectors) diff --git a/serde-reflection/src/trace.rs b/serde-reflection/src/trace.rs index 8a82d023d..fe3347616 100644 --- a/serde-reflection/src/trace.rs +++ b/serde-reflection/src/trace.rs @@ -57,6 +57,24 @@ pub struct TracerConfig { pub(crate) record_samples_for_newtype_structs: bool, pub(crate) record_samples_for_tuple_structs: bool, pub(crate) record_samples_for_structs: bool, + pub(crate) default_bool_value: bool, + pub(crate) default_u8_value: u8, + pub(crate) default_u16_value: u16, + pub(crate) default_u32_value: u32, + pub(crate) default_u64_value: u64, + pub(crate) default_u128_value: u128, + pub(crate) default_i8_value: i8, + pub(crate) default_i16_value: i16, + pub(crate) default_i32_value: i32, + pub(crate) default_i64_value: i64, + pub(crate) default_i128_value: i128, + pub(crate) default_f32_value: f32, + pub(crate) default_f64_value: f64, + pub(crate) default_char_value: char, + pub(crate) default_borrowed_str_value: &'static str, + pub(crate) default_string_value: String, + pub(crate) default_borrowed_bytes_value: &'static [u8], + pub(crate) default_byte_buf_value: Vec, } impl Default for TracerConfig { @@ -67,10 +85,38 @@ impl Default for TracerConfig { record_samples_for_newtype_structs: true, record_samples_for_tuple_structs: false, record_samples_for_structs: false, + default_bool_value: false, + default_u8_value: 0, + default_u16_value: 0, + default_u32_value: 0, + default_u64_value: 0, + default_u128_value: 0, + default_i8_value: 0, + default_i16_value: 0, + default_i32_value: 0, + default_i64_value: 0, + default_i128_value: 0, + default_f32_value: 0.0, + default_f64_value: 0.0, + default_char_value: 'A', + default_borrowed_str_value: "", + default_string_value: String::new(), + default_borrowed_bytes_value: b"", + default_byte_buf_value: Vec::new(), } } } +macro_rules! define_default_value_setter { + ($method:ident, $ty:ty) => { + /// The default serialized value for this primitive type. + pub fn $method(mut self, value: $ty) -> Self { + self.$method = value; + self + } + }; +} + impl TracerConfig { /// Whether to trace the human readable encoding of (de)serialization. #[allow(clippy::wrong_self_convention)] @@ -96,6 +142,25 @@ impl TracerConfig { self.record_samples_for_structs = value; self } + + define_default_value_setter!(default_bool_value, bool); + define_default_value_setter!(default_u8_value, u8); + define_default_value_setter!(default_u16_value, u16); + define_default_value_setter!(default_u32_value, u32); + define_default_value_setter!(default_u64_value, u64); + define_default_value_setter!(default_u128_value, u128); + define_default_value_setter!(default_i8_value, i8); + define_default_value_setter!(default_i16_value, i16); + define_default_value_setter!(default_i32_value, i32); + define_default_value_setter!(default_i64_value, i64); + define_default_value_setter!(default_i128_value, i128); + define_default_value_setter!(default_f32_value, f32); + define_default_value_setter!(default_f64_value, f64); + define_default_value_setter!(default_char_value, char); + define_default_value_setter!(default_borrowed_str_value, &'static str); + define_default_value_setter!(default_string_value, String); + define_default_value_setter!(default_borrowed_bytes_value, &'static [u8]); + define_default_value_setter!(default_byte_buf_value, Vec); } impl Tracer { diff --git a/serde-reflection/tests/serde.rs b/serde-reflection/tests/serde.rs index 811d5f69f..c31f1b405 100644 --- a/serde-reflection/tests/serde.rs +++ b/serde-reflection/tests/serde.rs @@ -457,3 +457,48 @@ fn test_repeated_tracing() { )))))) ); } + +#[test] +fn test_default_value_for_primitive_types() { + let config = TracerConfig::default() + .default_u8_value(1) + .default_u16_value(2) + .default_u32_value(3) + .default_u64_value(4) + .default_u128_value(5) + .default_i8_value(6) + .default_i16_value(7) + .default_i32_value(8) + .default_i64_value(9) + .default_i128_value(10) + .default_string_value("A string".into()) + .default_borrowed_str_value("A borrowed str"); + let mut tracer = Tracer::new(config); + let samples = Samples::new(); + + let (format, value) = tracer + .trace_type_once::(&samples) + .unwrap(); + assert_eq!(format, Format::U8); // Not a container + assert_eq!(value.get(), 1); + + let (format, value) = tracer.trace_type_once::(&samples).unwrap(); + assert_eq!(format, Format::U8); + assert_eq!(value, 1); + + let (format, value) = tracer.trace_type_once::(&samples).unwrap(); + assert_eq!(format, Format::U16); + assert_eq!(value, 2); + + let (format, value) = tracer.trace_type_once::(&samples).unwrap(); + assert_eq!(format, Format::I128); + assert_eq!(value, 10); + + let (format, value) = tracer.trace_type_once::(&samples).unwrap(); + assert_eq!(format, Format::Str); + assert_eq!(value.as_str(), "A string"); + + let (format, value) = tracer.trace_type_once::<&str>(&samples).unwrap(); + assert_eq!(format, Format::Str); + assert_eq!(value, "A borrowed str"); +}