Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow configuring the default value for primitives types #39

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions serde-reflection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ use the crate [`serde-name`](https://crates.io/crates/serde-name) and its adapte
terminate. (Work around: re-order the variants. For instance `enum List {
Some(Box<List>), None}` must be rewritten `enum List { None, Some(Box<List>)}`.)

* Certain standard types such as `std::num::NonZeroU8` may not be tracked as a
container and appear simply as their underlying primitive type (e.g. `u8`) in the
formats. This loss of information makes it difficult to use `trace_value` to work
around deserialization invariants (see example below). As a work around, you may
override the default for the primitive type using `TracerConfig` (e.g. `let config =
TracerConfig::default().default_u8_value(1);`).

### Security CAVEAT

At this time, `HashSet<T>` and `BTreeSet<T>` are treated as sequences (i.e. vectors)
Expand Down
36 changes: 18 additions & 18 deletions serde-reflection/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,143 +50,143 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> {
V: Visitor<'de>,
{
self.format.unify(Format::Bool)?;
visitor.visit_bool(false)
visitor.visit_bool(self.tracer.config.default_bool_value)
}

fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::I8)?;
visitor.visit_i8(0)
visitor.visit_i8(self.tracer.config.default_i8_value)
}

fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::I16)?;
visitor.visit_i16(0)
visitor.visit_i16(self.tracer.config.default_i16_value)
}

fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::I32)?;
visitor.visit_i32(0)
visitor.visit_i32(self.tracer.config.default_i32_value)
}

fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::I64)?;
visitor.visit_i64(0)
visitor.visit_i64(self.tracer.config.default_i64_value)
}

fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::I128)?;
visitor.visit_i128(0)
visitor.visit_i128(self.tracer.config.default_i128_value)
}

fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::U8)?;
visitor.visit_u8(0)
visitor.visit_u8(self.tracer.config.default_u8_value)
}

fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::U16)?;
visitor.visit_u16(0)
visitor.visit_u16(self.tracer.config.default_u16_value)
}

fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::U32)?;
visitor.visit_u32(0)
visitor.visit_u32(self.tracer.config.default_u32_value)
}

fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::U64)?;
visitor.visit_u64(0)
visitor.visit_u64(self.tracer.config.default_u64_value)
}

fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::U128)?;
visitor.visit_u128(0)
visitor.visit_u128(self.tracer.config.default_u128_value)
}

fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::F32)?;
visitor.visit_f32(0.0)
visitor.visit_f32(self.tracer.config.default_f32_value)
}

fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::F64)?;
visitor.visit_f64(0.0)
visitor.visit_f64(self.tracer.config.default_f64_value)
}

fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::Char)?;
visitor.visit_char('A')
visitor.visit_char(self.tracer.config.default_char_value)
}

fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::Str)?;
visitor.visit_borrowed_str("")
visitor.visit_borrowed_str(self.tracer.config.default_borrowed_str_value)
}

fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::Str)?;
visitor.visit_string(String::new())
visitor.visit_string(self.tracer.config.default_string_value.clone())
}

fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::Bytes)?;
visitor.visit_borrowed_bytes(b"")
visitor.visit_borrowed_bytes(self.tracer.config.default_borrowed_bytes_value)
}

fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.format.unify(Format::Bytes)?;
visitor.visit_byte_buf(Vec::new())
visitor.visit_byte_buf(self.tracer.config.default_byte_buf_value.clone())
}

fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
Expand Down
7 changes: 7 additions & 0 deletions serde-reflection/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@
//! terminate. (Work around: re-order the variants. For instance `enum List {
//! Some(Box<List>), None}` must be rewritten `enum List { None, Some(Box<List>)}`.)
//!
//! * Certain standard types such as `std::num::NonZeroU8` may not be tracked as a
//! container and appear simply as their underlying primitive type (e.g. `u8`) in the
//! formats. This loss of information makes it difficult to use `trace_value` to work
//! around deserialization invariants (see example below). As a work around, you may
//! override the default for the primitive type using `TracerConfig` (e.g. `let config =
//! TracerConfig::default().default_u8_value(1);`).
//!
//! ## Security CAVEAT
//!
//! At this time, `HashSet<T>` and `BTreeSet<T>` are treated as sequences (i.e. vectors)
Expand Down
65 changes: 65 additions & 0 deletions serde-reflection/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ pub struct TracerConfig {
pub(crate) record_samples_for_newtype_structs: bool,
pub(crate) record_samples_for_tuple_structs: bool,
pub(crate) record_samples_for_structs: bool,
pub(crate) default_bool_value: bool,
pub(crate) default_u8_value: u8,
pub(crate) default_u16_value: u16,
pub(crate) default_u32_value: u32,
pub(crate) default_u64_value: u64,
pub(crate) default_u128_value: u128,
pub(crate) default_i8_value: i8,
pub(crate) default_i16_value: i16,
pub(crate) default_i32_value: i32,
pub(crate) default_i64_value: i64,
pub(crate) default_i128_value: i128,
pub(crate) default_f32_value: f32,
pub(crate) default_f64_value: f64,
pub(crate) default_char_value: char,
pub(crate) default_borrowed_str_value: &'static str,
pub(crate) default_string_value: String,
pub(crate) default_borrowed_bytes_value: &'static [u8],
pub(crate) default_byte_buf_value: Vec<u8>,
}

impl Default for TracerConfig {
Expand All @@ -67,10 +85,38 @@ impl Default for TracerConfig {
record_samples_for_newtype_structs: true,
record_samples_for_tuple_structs: false,
record_samples_for_structs: false,
default_bool_value: false,
default_u8_value: 0,
default_u16_value: 0,
default_u32_value: 0,
default_u64_value: 0,
default_u128_value: 0,
default_i8_value: 0,
default_i16_value: 0,
default_i32_value: 0,
default_i64_value: 0,
default_i128_value: 0,
default_f32_value: 0.0,
default_f64_value: 0.0,
default_char_value: 'A',
default_borrowed_str_value: "",
default_string_value: String::new(),
default_borrowed_bytes_value: b"",
default_byte_buf_value: Vec::new(),
}
}
}

macro_rules! define_default_value_setter {
($method:ident, $ty:ty) => {
/// The default serialized value for this primitive type.
pub fn $method(mut self, value: $ty) -> Self {
self.$method = value;
self
}
};
}

impl TracerConfig {
/// Whether to trace the human readable encoding of (de)serialization.
#[allow(clippy::wrong_self_convention)]
Expand All @@ -96,6 +142,25 @@ impl TracerConfig {
self.record_samples_for_structs = value;
self
}

define_default_value_setter!(default_bool_value, bool);
define_default_value_setter!(default_u8_value, u8);
define_default_value_setter!(default_u16_value, u16);
define_default_value_setter!(default_u32_value, u32);
define_default_value_setter!(default_u64_value, u64);
define_default_value_setter!(default_u128_value, u128);
define_default_value_setter!(default_i8_value, i8);
define_default_value_setter!(default_i16_value, i16);
define_default_value_setter!(default_i32_value, i32);
define_default_value_setter!(default_i64_value, i64);
define_default_value_setter!(default_i128_value, i128);
define_default_value_setter!(default_f32_value, f32);
define_default_value_setter!(default_f64_value, f64);
define_default_value_setter!(default_char_value, char);
define_default_value_setter!(default_borrowed_str_value, &'static str);
define_default_value_setter!(default_string_value, String);
define_default_value_setter!(default_borrowed_bytes_value, &'static [u8]);
define_default_value_setter!(default_byte_buf_value, Vec<u8>);
}

impl Tracer {
Expand Down
45 changes: 45 additions & 0 deletions serde-reflection/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,48 @@ fn test_repeated_tracing() {
))))))
);
}

#[test]
fn test_default_value_for_primitive_types() {
let config = TracerConfig::default()
.default_u8_value(1)
.default_u16_value(2)
.default_u32_value(3)
.default_u64_value(4)
.default_u128_value(5)
.default_i8_value(6)
.default_i16_value(7)
.default_i32_value(8)
.default_i64_value(9)
.default_i128_value(10)
.default_string_value("A string".into())
.default_borrowed_str_value("A borrowed str");
let mut tracer = Tracer::new(config);
let samples = Samples::new();

let (format, value) = tracer
.trace_type_once::<std::num::NonZeroU8>(&samples)
.unwrap();
assert_eq!(format, Format::U8); // Not a container
assert_eq!(value.get(), 1);

let (format, value) = tracer.trace_type_once::<u8>(&samples).unwrap();
assert_eq!(format, Format::U8);
assert_eq!(value, 1);

let (format, value) = tracer.trace_type_once::<u16>(&samples).unwrap();
assert_eq!(format, Format::U16);
assert_eq!(value, 2);

let (format, value) = tracer.trace_type_once::<i128>(&samples).unwrap();
assert_eq!(format, Format::I128);
assert_eq!(value, 10);

let (format, value) = tracer.trace_type_once::<String>(&samples).unwrap();
assert_eq!(format, Format::Str);
assert_eq!(value.as_str(), "A string");

let (format, value) = tracer.trace_type_once::<&str>(&samples).unwrap();
assert_eq!(format, Format::Str);
assert_eq!(value, "A borrowed str");
}
Loading