Skip to content

Commit

Permalink
fix deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
kralverde committed Feb 16, 2025
1 parent 01c73ee commit 4bbb46d
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 97 deletions.
42 changes: 36 additions & 6 deletions pumpkin-nbt/src/compound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,40 @@ impl NbtCompound {
}
}

pub fn skip_content<R>(reader: &mut ReadAdaptor<R>) -> Result<(), Error>
where
R: Read,
{
loop {
let tag_id = match reader.get_u8_be() {
Ok(id) => id,
Err(err) => match err {
Error::Incomplete(err) => match err.kind() {
ErrorKind::UnexpectedEof => {
break;
}
_ => {
return Err(Error::Incomplete(err));
}
},
_ => {
return Err(err);
}
},
};
if tag_id == END_ID {
break;
}

let len = reader.get_u16_be()?;
reader.skip_bytes(len as u64)?;

NbtTag::skip_data(reader, tag_id)?;
}

Ok(())
}

pub fn deserialize_content<R>(reader: &mut ReadAdaptor<R>) -> Result<NbtCompound, Error>
where
R: Read,
Expand Down Expand Up @@ -45,12 +79,8 @@ impl NbtCompound {
}

let name = get_nbt_string(reader)?;

if let Ok(tag) = NbtTag::deserialize_data(reader, tag_id) {
compound.put(&name, tag);
} else {
break;
}
let tag = NbtTag::deserialize_data(reader, tag_id)?;
compound.put(&name, tag);
}

Ok(compound)
Expand Down
163 changes: 112 additions & 51 deletions pumpkin-nbt/src/deserializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ impl<R: Read> ReadAdaptor<R> {
}

impl<R: Read> ReadAdaptor<R> {
pub fn skip_bytes(&mut self, count: u64) -> Result<()> {
let _ = io::copy(&mut self.reader.by_ref().take(count), &mut io::sink())
.map_err(Error::Incomplete)?;
Ok(())
}

//TODO: Macroize this
pub fn get_u8_be(&mut self) -> Result<u8> {
let mut buf = [0u8];
Expand Down Expand Up @@ -103,7 +109,7 @@ impl<R: Read> ReadAdaptor<R> {
#[derive(Debug)]
pub struct Deserializer<R: Read> {
input: ReadAdaptor<R>,
tag_to_deserialize: Option<u8>,
tag_to_deserialize_stack: Vec<u8>,
// Yes, this breaks with recursion. Just an attempt at a sanity check
in_list: bool,
is_named: bool,
Expand All @@ -113,7 +119,7 @@ impl<R: Read> Deserializer<R> {
pub fn new(input: R, is_named: bool) -> Self {
Deserializer {
input: ReadAdaptor { reader: input },
tag_to_deserialize: None,
tag_to_deserialize_stack: Vec::new(),
in_list: false,
is_named,
}
Expand Down Expand Up @@ -143,54 +149,71 @@ impl<'de, R: Read> de::Deserializer<'de> for &mut Deserializer<R> {

forward_to_deserialize_any! {
i8 i16 i32 i64 f32 f64 char str string unit unit_struct seq tuple tuple_struct
ignored_any bytes newtype_struct byte_buf
bytes newtype_struct byte_buf
}

fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let Some(tag) = self.tag_to_deserialize_stack.pop() else {
return Err(Error::SerdeError("Ignoring nothing!".to_string()));
};

NbtTag::skip_data(&mut self.input, tag)?;
visitor.visit_unit()
}

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let tag_to_deserialize = self.tag_to_deserialize.unwrap();

let list_type = match tag_to_deserialize {
LIST_ID => Some(self.input.get_u8_be()?),
INT_ARRAY_ID => Some(INT_ID),
LONG_ARRAY_ID => Some(LONG_ID),
BYTE_ARRAY_ID => Some(BYTE_ID),
_ => None,
let Some(tag_to_deserialize) = self.tag_to_deserialize_stack.pop() else {
return Err(Error::SerdeError(
"The top level must be a component (e.g. a struct)".to_string(),
));
};

if let Some(list_type) = list_type {
let remaining_values = self.input.get_i32_be()?;
if remaining_values < 0 {
return Err(Error::NegativeLength(remaining_values));
match tag_to_deserialize {
END_ID => Err(Error::SerdeError(
"Trying to deserialize an END tag!".to_string(),
)),
LIST_ID | INT_ARRAY_ID | LONG_ARRAY_ID | BYTE_ARRAY_ID => {
let list_type = match tag_to_deserialize {
LIST_ID => self.input.get_u8_be()?,
INT_ARRAY_ID => INT_ID,
LONG_ARRAY_ID => LONG_ID,
BYTE_ARRAY_ID => BYTE_ID,
_ => unreachable!(),
};

let remaining_values = self.input.get_i32_be()?;
if remaining_values < 0 {
return Err(Error::NegativeLength(remaining_values));
}

let result = visitor.visit_seq(ListAccess {
de: self,
list_type,
remaining_values: remaining_values as usize,
})?;
Ok(result)
}
COMPOUND_ID => visitor.visit_map(CompoundAccess { de: self }),
_ => {
let result = match NbtTag::deserialize_data(&mut self.input, tag_to_deserialize)? {
NbtTag::Byte(value) => visitor.visit_i8::<Error>(value)?,
NbtTag::Short(value) => visitor.visit_i16::<Error>(value)?,
NbtTag::Int(value) => visitor.visit_i32::<Error>(value)?,
NbtTag::Long(value) => visitor.visit_i64::<Error>(value)?,
NbtTag::Float(value) => visitor.visit_f32::<Error>(value)?,
NbtTag::Double(value) => visitor.visit_f64::<Error>(value)?,
NbtTag::String(value) => visitor.visit_string::<Error>(value)?,
_ => unreachable!(),
};
Ok(result)
}

return visitor.visit_seq(ListAccess {
de: self,
list_type,
remaining_values: remaining_values as usize,
});
}

// TODO: Just skip values for the ignored values so we dont do the work of
// parsing/allocating space for the NBT representations
let result: Result<V::Value> = Ok(
match NbtTag::deserialize_data(&mut self.input, tag_to_deserialize)? {
NbtTag::Byte(value) => visitor.visit_i8::<Error>(value)?,
NbtTag::Short(value) => visitor.visit_i16::<Error>(value)?,
NbtTag::Int(value) => visitor.visit_i32::<Error>(value)?,
NbtTag::Long(value) => visitor.visit_i64::<Error>(value)?,
NbtTag::Float(value) => visitor.visit_f32::<Error>(value)?,
NbtTag::Double(value) => visitor.visit_f64::<Error>(value)?,
NbtTag::String(value) => visitor.visit_string::<Error>(value)?,
// If we get to this point, we dont actually need the data (its omitted from the
// struct we're deserializing). Just return None.
_ => visitor.visit_none::<Error>()?,
},
);
self.tag_to_deserialize = None;
result
}

fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
Expand Down Expand Up @@ -238,15 +261,25 @@ impl<'de, R: Read> de::Deserializer<'de> for &mut Deserializer<R> {
where
V: Visitor<'de>,
{
if self.tag_to_deserialize.unwrap() == BYTE_ID {
let value = self.input.get_u8_be()?;
if value != 0 {
visitor.visit_bool(true)
if let Some(tag_id) = self.tag_to_deserialize_stack.last() {
if *tag_id == BYTE_ID {
let value = self.input.get_u8_be()?;
if value != 0 {
visitor.visit_bool(true)
} else {
visitor.visit_bool(false)
}
} else {
visitor.visit_bool(false)
Err(Error::UnsupportedType(format!(
"Non-byte bool (found type {})",
tag_id
)))
}
} else {
Err(Error::UnsupportedType("Non-byte bool".to_string()))
Err(Error::SerdeError(
"Wanted to deserialize a bool, but there was no type hint on the stack!"
.to_string(),
))
}
}

Expand All @@ -263,7 +296,7 @@ impl<'de, R: Read> de::Deserializer<'de> for &mut Deserializer<R> {
visitor.visit_enum(variant.into_deserializer())
}

fn deserialize_option<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
Expand All @@ -275,7 +308,13 @@ impl<'de, R: Read> de::Deserializer<'de> for &mut Deserializer<R> {
where
V: Visitor<'de>,
{
if self.tag_to_deserialize.is_none() {
if let Some(tag_id) = self.tag_to_deserialize_stack.last() {
if *tag_id != COMPOUND_ID {
return Err(Error::SerdeError(
"Trying to deserialize a map without a compound ID".to_string(),
));
}
} else {
let next_byte = self.input.get_u8_be()?;
if next_byte != COMPOUND_ID {
return Err(Error::NoRootCompound(next_byte));
Expand Down Expand Up @@ -328,13 +367,13 @@ impl<'de, R: Read> MapAccess<'de> for CompoundAccess<'_, R> {
K: DeserializeSeed<'de>,
{
let tag = self.de.input.get_u8_be()?;
self.de.tag_to_deserialize = Some(tag);
self.de.tag_to_deserialize_stack.push(tag);

if tag == END_ID {
return Ok(None);
}

seed.deserialize(&mut *self.de).map(Some)
seed.deserialize(MapKey { de: self.de }).map(Some)
}

fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
Expand All @@ -345,6 +384,27 @@ impl<'de, R: Read> MapAccess<'de> for CompoundAccess<'_, R> {
}
}

struct MapKey<'a, R: Read> {
de: &'a mut Deserializer<R>,
}

impl<'de, R: Read> de::Deserializer<'de> for MapKey<'_, R> {
type Error = Error;

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
let key = get_nbt_string(&mut self.de.input)?;
visitor.visit_string(key)
}

forward_to_deserialize_any! {
bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit unit_struct seq tuple tuple_struct map
struct identifier ignored_any bytes enum newtype_struct byte_buf option
}
}

struct ListAccess<'a, R: Read> {
de: &'a mut Deserializer<R>,
remaining_values: usize,
Expand All @@ -367,10 +427,11 @@ impl<'de, R: Read> SeqAccess<'de> for ListAccess<'_, R> {
}

self.remaining_values -= 1;
self.de.tag_to_deserialize = Some(self.list_type);
self.de.tag_to_deserialize_stack.push(self.list_type);
self.de.in_list = true;
let result = seed.deserialize(&mut *self.de).map(Some);
self.de.in_list = false;

result
}
}
Loading

0 comments on commit 4bbb46d

Please sign in to comment.