diff --git a/src/de.rs b/src/de.rs index 08534a2..a04eacc 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1,9 +1,138 @@ // Copyright (c) The Diem Core Contributors // SPDX-License-Identifier: Apache-2.0 +//! BCS deserialization + use crate::error::{Error, Result}; use serde::de::{self, Deserialize, DeserializeOwned, DeserializeSeed, IntoDeserializer, Visitor}; -use std::{convert::TryFrom, io::Read}; +use std::{convert::TryFrom, io::Read, marker::PhantomData}; + +/// Builder API to configure deserialization. +/// +/// # Examples +/// +/// ``` +/// use serde::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Ip([u8; 4]); +/// +/// #[derive(Deserialize)] +/// struct Port(u16); +/// +/// #[derive(Deserialize)] +/// struct SocketAddr { +/// ip: Ip, +/// port: Port, +/// } +/// +/// let bytes = vec![0x7f, 0x00, 0x00, 0x01, 0x41, 0x1f]; +/// let socket_addr: SocketAddr = +/// bcs::de::Builder::new() +/// .max_sequence_length(1_024 * 1_024) +/// .max_container_depth(64) +/// .deserialize_bytes(&bytes) +/// .unwrap(); +/// +/// assert_eq!(socket_addr.ip.0, [127, 0, 0, 1]); +/// assert_eq!(socket_addr.port.0, 8001); +/// ``` +pub struct Builder { + max_container_depth: usize, + max_sequence_length: usize, + seed: T, +} + +impl Builder> { + /// Creates a `Builder` instance with default parameter values. + pub fn new() -> Self { + Self::with_seed(PhantomData) + } +} + +impl Builder { + /// Creates a `Builder` with the given seed value for stateful deserialization. + /// The other parameters are initialized with default values. + pub fn with_seed(seed: S) -> Builder { + Self { + max_container_depth: crate::MAX_CONTAINER_DEPTH, + max_sequence_length: crate::MAX_SEQUENCE_LENGTH, + seed, + } + } +} + +impl Default for Builder { + fn default() -> Self { + Builder::with_seed(Default::default()) + } +} + +impl Builder { + /// Sets the limit on depth of nested BCS data. + /// + /// The default is the [well-known limit][crate::MAX_CONTAINER_DEPTH] + /// defined for BCS. + /// If the value passed is larger than that, deserialization with this + /// `Builder` will fail with an error. + pub fn max_container_depth(mut self, limit: usize) -> Self { + self.max_container_depth = limit; + self + } + + /// Set the length limit on variable-length sequences: byte arrays, + /// strings, sequences and maps. Encountering an encoded sequence + /// with a greater length will cause deserialization to fail with an + /// `ExceededMaxLen` error. + /// + /// The default is the [well-known limit][crate::MAX_SEQUENCE_LENGTH] + /// defined for BCS. + /// If the value passed is larger than that, deserialization with this + /// `Builder` will fail with an error. + pub fn max_sequence_length(mut self, limit: usize) -> Self { + self.max_sequence_length = limit; + self + } + + fn check_sanity(&self) -> Result<(), Error> { + if self.max_container_depth > crate::MAX_CONTAINER_DEPTH { + return Err(Error::NotSupported( + "container depth limit exceeds the max allowed depth", + )); + } + if self.max_sequence_length > crate::MAX_SEQUENCE_LENGTH { + return Err(Error::NotSupported( + "sequence length limit exceeds the max sequence length", + )); + } + Ok(()) + } +} + +impl<'a, T> Builder +where + T: DeserializeSeed<'a>, +{ + /// Deserializes a value from an `&[u8]` using the configured parameters. + pub fn deserialize_bytes(self, bytes: &'a [u8]) -> Result { + self.check_sanity()?; + let mut deserializer = + Deserializer::new(bytes, self.max_container_depth, self.max_sequence_length); + let t = self.seed.deserialize(&mut deserializer)?; + deserializer.end()?; + Ok(t) + } + + /// Deserializes a value from an implementation of [`Read`] using the configured parameters. + pub fn deserialize_reader(self, reader: &'a mut impl Read) -> Result { + self.check_sanity()?; + let mut deserializer = + Deserializer::from_reader(reader, self.max_container_depth, self.max_sequence_length); + let t = self.seed.deserialize(&mut deserializer)?; + deserializer.end()?; + Ok(t) + } +} /// Deserializes a `&[u8]` into a type. /// @@ -38,10 +167,7 @@ pub fn from_bytes<'a, T>(bytes: &'a [u8]) -> Result where T: Deserialize<'a>, { - let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH); - let t = T::deserialize(&mut deserializer)?; - deserializer.end()?; - Ok(t) + Builder::new().deserialize_bytes(bytes) } /// Same as `from_bytes` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH` @@ -50,13 +176,9 @@ pub fn from_bytes_with_limit<'a, T>(bytes: &'a [u8], limit: usize) -> Result where T: Deserialize<'a>, { - if limit > crate::MAX_CONTAINER_DEPTH { - return Err(Error::NotSupported("limit exceeds the max allowed depth")); - } - let mut deserializer = Deserializer::new(bytes, limit); - let t = T::deserialize(&mut deserializer)?; - deserializer.end()?; - Ok(t) + Builder::new() + .max_container_depth(limit) + .deserialize_bytes(bytes) } /// Perform a stateful deserialization from a `&[u8]` using the provided `seed`. @@ -64,10 +186,7 @@ pub fn from_bytes_seed<'a, T>(seed: T, bytes: &'a [u8]) -> Result where T: DeserializeSeed<'a>, { - let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH); - let t = seed.deserialize(&mut deserializer)?; - deserializer.end()?; - Ok(t) + Builder::with_seed(seed).deserialize_bytes(bytes) } /// Same as `from_bytes_seed` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH` @@ -76,13 +195,9 @@ pub fn from_bytes_seed_with_limit<'a, T>(seed: T, bytes: &'a [u8], limit: usize) where T: DeserializeSeed<'a>, { - if limit > crate::MAX_CONTAINER_DEPTH { - return Err(Error::NotSupported("limit exceeds the max allowed depth")); - } - let mut deserializer = Deserializer::new(bytes, limit); - let t = seed.deserialize(&mut deserializer)?; - deserializer.end()?; - Ok(t) + Builder::with_seed(seed) + .max_container_depth(limit) + .deserialize_bytes(bytes) } /// Deserialize a type from an implementation of [`Read`]. @@ -90,10 +205,7 @@ pub fn from_reader(mut reader: impl Read) -> Result where T: DeserializeOwned, { - let mut deserializer = Deserializer::from_reader(&mut reader, crate::MAX_CONTAINER_DEPTH); - let t = T::deserialize(&mut deserializer)?; - deserializer.end()?; - Ok(t) + Builder::new().deserialize_reader(&mut reader) } /// Same as `from_reader_seed` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH` @@ -102,13 +214,9 @@ pub fn from_reader_with_limit(mut reader: impl Read, limit: usize) -> Result< where T: DeserializeOwned, { - if limit > crate::MAX_CONTAINER_DEPTH { - return Err(Error::NotSupported("limit exceeds the max allowed depth")); - } - let mut deserializer = Deserializer::from_reader(&mut reader, limit); - let t = T::deserialize(&mut deserializer)?; - deserializer.end()?; - Ok(t) + Builder::new() + .max_container_depth(limit) + .deserialize_reader(&mut reader) } /// Deserialize a type from an implementation of [`Read`] using the provided seed @@ -116,10 +224,7 @@ pub fn from_reader_seed(seed: T, mut reader: impl Read) -> Result where for<'a> T: DeserializeSeed<'a, Value = V>, { - let mut deserializer = Deserializer::from_reader(&mut reader, crate::MAX_CONTAINER_DEPTH); - let t = seed.deserialize(&mut deserializer)?; - deserializer.end()?; - Ok(t) + Builder::with_seed(seed).deserialize_reader(&mut reader) } /// Same as `from_reader_seed` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH` @@ -128,26 +233,28 @@ pub fn from_reader_seed_with_limit(seed: T, mut reader: impl Read, limit: where for<'a> T: DeserializeSeed<'a, Value = V>, { - if limit > crate::MAX_CONTAINER_DEPTH { - return Err(Error::NotSupported("limit exceeds the max allowed depth")); - } - let mut deserializer = Deserializer::from_reader(&mut reader, limit); - let t = seed.deserialize(&mut deserializer)?; - deserializer.end()?; - Ok(t) + Builder::with_seed(seed) + .max_container_depth(limit) + .deserialize_reader(&mut reader) } /// Deserialization implementation for BCS struct Deserializer { input: R, max_remaining_depth: usize, + max_sequence_length: usize, } impl<'de, R: Read> Deserializer> { - fn from_reader(input: &'de mut R, max_remaining_depth: usize) -> Self { + fn from_reader( + input: &'de mut R, + max_remaining_depth: usize, + max_sequence_length: usize, + ) -> Self { Deserializer { input: TeeReader::new(input), max_remaining_depth, + max_sequence_length, } } } @@ -155,10 +262,11 @@ impl<'de, R: Read> Deserializer> { impl<'de> Deserializer<&'de [u8]> { /// Creates a new `Deserializer` which will be deserializing the provided /// input. - fn new(input: &'de [u8], max_remaining_depth: usize) -> Self { + fn new(input: &'de [u8], max_remaining_depth: usize, max_sequence_length: usize) -> Self { Deserializer { input, max_remaining_depth, + max_sequence_length, } } } @@ -191,7 +299,11 @@ impl<'de, R: Read> Read for TeeReader<'de, R> { } } -trait BcsDeserializer<'de> { +trait ValidateLength { + fn validate_length(&self, parsed_value: u32) -> Result; +} + +trait BcsDeserializer<'de>: ValidateLength { type MaybeBorrowedBytes: AsRef<[u8]>; fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()>; @@ -281,11 +393,8 @@ trait BcsDeserializer<'de> { } fn parse_length(&mut self) -> Result { - let len = self.parse_u32_from_uleb128()? as usize; - if len > crate::MAX_SEQUENCE_LENGTH { - return Err(Error::ExceededMaxLen(len)); - } - Ok(len) + let parsed_value = self.parse_u32_from_uleb128()?; + self.validate_length(parsed_value) } } @@ -303,6 +412,16 @@ impl<'de, R: Read> Deserializer> { } } +impl ValidateLength for Deserializer { + fn validate_length(&self, parsed_value: u32) -> Result { + let len = parsed_value as usize; + if len > self.max_sequence_length { + return Err(Error::ExceededMaxLen(len)); + } + Ok(len) + } +} + impl<'de, R: Read> BcsDeserializer<'de> for Deserializer> { type MaybeBorrowedBytes = Vec; diff --git a/src/lib.rs b/src/lib.rs index 4ede21a..4358db8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -303,7 +303,7 @@ //! # Ok(())} //! ``` -mod de; +pub mod de; mod error; mod ser; pub mod test_helpers; diff --git a/src/ser.rs b/src/ser.rs index 9bd42ce..86e8ebd 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -62,7 +62,9 @@ where T: ?Sized + Serialize, { if limit > crate::MAX_CONTAINER_DEPTH { - return Err(Error::NotSupported("limit exceeds the max allowed depth")); + return Err(Error::NotSupported( + "container depth limit exceeds the max allowed depth", + )); } let mut output = Vec::new(); serialize_into_with_limit(&mut output, value, limit)?; @@ -87,7 +89,9 @@ where T: ?Sized + Serialize, { if limit > crate::MAX_CONTAINER_DEPTH { - return Err(Error::NotSupported("limit exceeds the max allowed depth")); + return Err(Error::NotSupported( + "container depth limit exceeds the max allowed depth", + )); } let serializer = Serializer::new(write, limit); value.serialize(serializer) @@ -126,7 +130,9 @@ where T: ?Sized + Serialize, { if limit > crate::MAX_CONTAINER_DEPTH { - return Err(Error::NotSupported("limit exceeds the max allowed depth")); + return Err(Error::NotSupported( + "container depth limit exceeds the max allowed depth", + )); } let mut counter = WriteCounter(0); serialize_into_with_limit(&mut counter, value, limit)?; diff --git a/tests/serde.rs b/tests/serde.rs index 8305f61..b6dd51f 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -393,6 +393,19 @@ fn sequence_too_long() { } } +#[test] +fn custom_sequence_length_limit() { + let seq = vec![0i32; 10]; + let bytes = to_bytes(&seq).unwrap(); + let res: Result, _> = bcs::de::Builder::new() + .max_sequence_length(9) + .deserialize_bytes(&bytes); + match res.unwrap_err() { + Error::ExceededMaxLen(len) => assert_eq!(len, 10), + _ => panic!(), + } +} + #[test] fn variable_lengths() { assert_eq!(to_bytes(&vec![(); 1]).unwrap(), vec![0x01]); @@ -767,7 +780,8 @@ fn test_recursion_limit() { // test customized limit let limit = 100; - let not_supported_err = Error::NotSupported("limit exceeds the max allowed depth"); + let container_depth_limit_not_supported_err = + Error::NotSupported("container depth limit exceeds the max allowed depth"); let l4 = List::integers(limit); assert_eq!( to_bytes_with_limit(&l4, limit), @@ -775,7 +789,7 @@ fn test_recursion_limit() { ); assert_eq!( to_bytes_with_limit(&l4, MAX_CONTAINER_DEPTH + 1), - Err(not_supported_err.clone()), + Err(container_depth_limit_not_supported_err.clone()), ); let bytes = to_bytes_with_limit(&l4, limit + 1).unwrap(); assert_eq!( @@ -785,7 +799,7 @@ fn test_recursion_limit() { assert_eq!(from_bytes_with_limit(&bytes, limit + 1), Ok(l4)); assert_eq!( from_bytes_with_limit::>(&bytes, MAX_CONTAINER_DEPTH + 1), - Err(not_supported_err) + Err(container_depth_limit_not_supported_err) ); }