diff --git a/pumpkin-nbt/src/compound.rs b/pumpkin-nbt/src/compound.rs index 5fa87d4bc..0d0cbe4d4 100644 --- a/pumpkin-nbt/src/compound.rs +++ b/pumpkin-nbt/src/compound.rs @@ -1,7 +1,8 @@ +use crate::deserializer::ReadAdaptor; +use crate::serializer::WriteAdaptor; use crate::tag::NbtTag; use crate::{get_nbt_string, Error, Nbt, END_ID}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use std::io::{Cursor, Write}; +use std::io::{ErrorKind, Read, Write}; use std::vec::IntoIter; #[derive(Clone, Debug, Default, PartialEq, PartialOrd)] @@ -16,46 +17,85 @@ impl NbtCompound { } } - pub fn deserialize_content(bytes: &mut impl Buf) -> Result { - let mut compound = NbtCompound::new(); - - while bytes.has_remaining() { - let tag_id = bytes.get_u8(); + pub fn skip_content(reader: &mut ReadAdaptor) -> Result<(), Error> + where + R: Read, + { + loop { + let tag_id = match reader.get_u8_be() { + Ok(id) => id, + Err(err) => match err { + Error::Incomplete(err) => match err.kind() { + ErrorKind::UnexpectedEof => { + break; + } + _ => { + return Err(Error::Incomplete(err)); + } + }, + _ => { + return Err(err); + } + }, + }; if tag_id == END_ID { break; } - let name = get_nbt_string(bytes).map_err(|_| Error::Cesu8DecodingError)?; + let len = reader.get_u16_be()?; + reader.skip_bytes(len as u64)?; + + NbtTag::skip_data(reader, tag_id)?; + } + + Ok(()) + } - if let Ok(tag) = NbtTag::deserialize_data(bytes, tag_id) { - compound.put(&name, tag); - } else { + pub fn deserialize_content(reader: &mut ReadAdaptor) -> Result + where + R: Read, + { + let mut compound = NbtCompound::new(); + + loop { + let tag_id = match reader.get_u8_be() { + Ok(id) => id, + Err(err) => match err { + Error::Incomplete(err) => match err.kind() { + ErrorKind::UnexpectedEof => { + break; + } + _ => { + return Err(Error::Incomplete(err)); + } + }, + _ => { + return Err(err); + } + }, + }; + if tag_id == END_ID { break; } + + let name = get_nbt_string(reader)?; + let tag = NbtTag::deserialize_data(reader, tag_id)?; + compound.put(&name, tag); } Ok(compound) } - pub fn deserialize_content_from_cursor( - cursor: &mut Cursor<&[u8]>, - ) -> Result { - Self::deserialize_content(cursor) - } - - pub fn serialize_content(&self) -> Bytes { - let mut bytes = BytesMut::new(); + pub fn serialize_content(&self, w: &mut WriteAdaptor) -> Result<(), Error> + where + W: Write, + { for (name, tag) in &self.child_tags { - bytes.put_u8(tag.get_type_id()); - bytes.put(NbtTag::String(name.clone()).serialize_data()); - bytes.put(tag.serialize_data()); + w.write_u8_be(tag.get_type_id())?; + NbtTag::String(name.clone()).serialize_data(w)?; + tag.serialize_data(w)?; } - bytes.put_u8(END_ID); - bytes.freeze() - } - - pub fn serialize_content_to_writer(&self, mut writer: W) -> std::io::Result<()> { - writer.write_all(&self.serialize_content())?; + w.write_u8_be(END_ID)?; Ok(()) } @@ -139,7 +179,7 @@ impl NbtCompound { self.get(name).and_then(|tag| tag.extract_string()) } - pub fn get_list(&self, name: &str) -> Option<&Vec> { + pub fn get_list(&self, name: &str) -> Option<&[NbtTag]> { self.get(name).and_then(|tag| tag.extract_list()) } @@ -147,11 +187,11 @@ impl NbtCompound { self.get(name).and_then(|tag| tag.extract_compound()) } - pub fn get_int_array(&self, name: &str) -> Option<&Vec> { + pub fn get_int_array(&self, name: &str) -> Option<&[i32]> { self.get(name).and_then(|tag| tag.extract_int_array()) } - pub fn get_long_array(&self, name: &str) -> Option<&Vec> { + pub fn get_long_array(&self, name: &str) -> Option<&[i64]> { self.get(name).and_then(|tag| tag.extract_long_array()) } } diff --git a/pumpkin-nbt/src/deserializer.rs b/pumpkin-nbt/src/deserializer.rs index 80e3b96f4..e6432a187 100644 --- a/pumpkin-nbt/src/deserializer.rs +++ b/pumpkin-nbt/src/deserializer.rs @@ -1,132 +1,328 @@ use crate::*; -use bytes::Buf; -use serde::de::{self, DeserializeSeed, MapAccess, SeqAccess, Visitor}; +use io::Read; +use serde::de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor}; use serde::{forward_to_deserialize_any, Deserialize}; -use std::io::Cursor; pub type Result = std::result::Result; #[derive(Debug)] -pub struct Deserializer<'de, T> { - input: &'de mut T, - tag_to_deserialize: Option, +pub struct ReadAdaptor { + reader: R, +} + +impl ReadAdaptor { + pub fn new(r: R) -> Self { + Self { reader: r } + } +} + +impl ReadAdaptor { + pub fn skip_bytes(&mut self, count: u64) -> Result<()> { + let _ = io::copy(&mut self.reader.by_ref().take(count), &mut io::sink()) + .map_err(Error::Incomplete)?; + Ok(()) + } + + //TODO: Macroize this + pub fn get_u8_be(&mut self) -> Result { + let mut buf = [0u8]; + self.reader + .read_exact(&mut buf) + .map_err(Error::Incomplete)?; + + Ok(u8::from_be_bytes(buf)) + } + + pub fn get_i8_be(&mut self) -> Result { + let mut buf = [0u8]; + self.reader + .read_exact(&mut buf) + .map_err(Error::Incomplete)?; + + Ok(i8::from_be_bytes(buf)) + } + + pub fn get_i16_be(&mut self) -> Result { + let mut buf = [0u8; 2]; + self.reader + .read_exact(&mut buf) + .map_err(Error::Incomplete)?; + + Ok(i16::from_be_bytes(buf)) + } + + pub fn get_u16_be(&mut self) -> Result { + let mut buf = [0u8; 2]; + self.reader + .read_exact(&mut buf) + .map_err(Error::Incomplete)?; + + Ok(u16::from_be_bytes(buf)) + } + + pub fn get_i32_be(&mut self) -> Result { + let mut buf = [0u8; 4]; + self.reader + .read_exact(&mut buf) + .map_err(Error::Incomplete)?; + + Ok(i32::from_be_bytes(buf)) + } + + pub fn get_i64_be(&mut self) -> Result { + let mut buf = [0u8; 8]; + self.reader + .read_exact(&mut buf) + .map_err(Error::Incomplete)?; + + Ok(i64::from_be_bytes(buf)) + } + + pub fn get_f32_be(&mut self) -> Result { + let mut buf = [0u8; 4]; + self.reader + .read_exact(&mut buf) + .map_err(Error::Incomplete)?; + + Ok(f32::from_be_bytes(buf)) + } + + pub fn get_f64_be(&mut self) -> Result { + let mut buf = [0u8; 8]; + self.reader + .read_exact(&mut buf) + .map_err(Error::Incomplete)?; + + Ok(f64::from_be_bytes(buf)) + } + + pub fn read_boxed_slice(&mut self, count: usize) -> Result> { + let mut buf = vec![0u8; count]; + self.reader + .read_exact(&mut buf) + .map_err(Error::Incomplete)?; + + Ok(buf.into()) + } +} + +#[derive(Debug)] +pub struct Deserializer { + input: ReadAdaptor, + tag_to_deserialize_stack: Vec, + // Yes, this breaks with recursion. Just an attempt at a sanity check + in_list: bool, is_named: bool, } -impl<'de, T: Buf> Deserializer<'de, T> { - pub fn new(input: &'de mut T, is_named: bool) -> Self { +impl Deserializer { + pub fn new(input: R, is_named: bool) -> Self { Deserializer { - input, - tag_to_deserialize: None, + input: ReadAdaptor { reader: input }, + tag_to_deserialize_stack: Vec::new(), + in_list: false, is_named, } } } /// Deserializes struct using Serde Deserializer from normal NBT -pub fn from_bytes<'a, T>(s: &'a mut impl Buf) -> Result +pub fn from_bytes<'a, T>(r: impl Read) -> Result where T: Deserialize<'a>, { - let mut deserializer = Deserializer::new(s, true); + let mut deserializer = Deserializer::new(r, true); T::deserialize(&mut deserializer) } -/// Deserializes struct using Serde Deserializer from normal NBT -pub fn from_bytes_unnamed<'a, T>(s: &'a mut impl Buf) -> Result +/// Deserializes struct using Serde Deserializer from network NBT +pub fn from_bytes_unnamed<'a, T>(r: impl Read) -> Result where T: Deserialize<'a>, { - let mut deserializer = Deserializer::new(s, false); + let mut deserializer = Deserializer::new(r, false); T::deserialize(&mut deserializer) } -pub fn from_cursor<'a, T>(cursor: &'a mut Cursor<&[u8]>) -> Result -where - T: Deserialize<'a>, -{ - let mut deserializer = Deserializer::new(cursor, true); - T::deserialize(&mut deserializer) -} +impl<'de, R: Read> de::Deserializer<'de> for &mut Deserializer { + type Error = Error; -pub fn from_cursor_unnamed<'a, T>(cursor: &'a mut Cursor<&[u8]>) -> Result -where - T: Deserialize<'a>, -{ - let mut deserializer = Deserializer::new(cursor, false); - T::deserialize(&mut deserializer) -} + forward_to_deserialize_any! { + i8 i16 i32 i64 f32 f64 char str string unit unit_struct seq tuple tuple_struct + bytes newtype_struct byte_buf + } -impl<'de, T: Buf> de::Deserializer<'de> for &mut Deserializer<'de, T> { - type Error = Error; + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let Some(tag) = self.tag_to_deserialize_stack.pop() else { + return Err(Error::SerdeError("Ignoring nothing!".to_string())); + }; - forward_to_deserialize_any!(i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 seq char str string bytes byte_buf tuple tuple_struct enum ignored_any unit unit_struct option newtype_struct); + NbtTag::skip_data(&mut self.input, tag)?; + visitor.visit_unit() + } fn deserialize_any(self, visitor: V) -> Result where V: Visitor<'de>, { - let tag_to_deserialize = self.tag_to_deserialize.unwrap(); - - let list_type = match tag_to_deserialize { - LIST_ID => Some(self.input.get_u8()), - INT_ARRAY_ID => Some(INT_ID), - LONG_ARRAY_ID => Some(LONG_ID), - BYTE_ARRAY_ID => Some(BYTE_ID), - _ => None, + let Some(tag_to_deserialize) = self.tag_to_deserialize_stack.pop() else { + return Err(Error::SerdeError( + "The top level must be a component (e.g. a struct)".to_string(), + )); }; - if let Some(list_type) = list_type { - let remaining_values = self.input.get_u32(); - return visitor.visit_seq(ListAccess { - de: self, - list_type, - remaining_values, - }); + match tag_to_deserialize { + END_ID => Err(Error::SerdeError( + "Trying to deserialize an END tag!".to_string(), + )), + LIST_ID | INT_ARRAY_ID | LONG_ARRAY_ID | BYTE_ARRAY_ID => { + let list_type = match tag_to_deserialize { + LIST_ID => self.input.get_u8_be()?, + INT_ARRAY_ID => INT_ID, + LONG_ARRAY_ID => LONG_ID, + BYTE_ARRAY_ID => BYTE_ID, + _ => unreachable!(), + }; + + let remaining_values = self.input.get_i32_be()?; + if remaining_values < 0 { + return Err(Error::NegativeLength(remaining_values)); + } + + let result = visitor.visit_seq(ListAccess { + de: self, + list_type, + remaining_values: remaining_values as usize, + })?; + Ok(result) + } + COMPOUND_ID => visitor.visit_map(CompoundAccess { de: self }), + _ => { + let result = match NbtTag::deserialize_data(&mut self.input, tag_to_deserialize)? { + NbtTag::Byte(value) => visitor.visit_i8::(value)?, + NbtTag::Short(value) => visitor.visit_i16::(value)?, + NbtTag::Int(value) => visitor.visit_i32::(value)?, + NbtTag::Long(value) => visitor.visit_i64::(value)?, + NbtTag::Float(value) => visitor.visit_f32::(value)?, + NbtTag::Double(value) => visitor.visit_f64::(value)?, + NbtTag::String(value) => visitor.visit_string::(value)?, + _ => unreachable!(), + }; + Ok(result) + } } + } - let result: Result = Ok( - match NbtTag::deserialize_data(self.input, tag_to_deserialize)? { - NbtTag::Byte(value) => visitor.visit_i8::(value)?, - NbtTag::Short(value) => visitor.visit_i16::(value)?, - NbtTag::Int(value) => visitor.visit_i32::(value)?, - NbtTag::Long(value) => visitor.visit_i64::(value)?, - NbtTag::Float(value) => visitor.visit_f32::(value)?, - NbtTag::Double(value) => visitor.visit_f64::(value)?, - NbtTag::String(value) => visitor.visit_string::(value)?, - _ => unreachable!(), - }, - ); - self.tag_to_deserialize = None; - result + fn deserialize_u8(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.in_list { + let value = self.input.get_u8_be()?; + visitor.visit_u8::(value) + } else { + Err(Error::UnsupportedType( + "u8; NBT only supports signed values".to_string(), + )) + } + } + + fn deserialize_u16(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::UnsupportedType( + "u16; NBT only supports signed values".to_string(), + )) + } + + fn deserialize_u32(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::UnsupportedType( + "u32; NBT only supports signed values".to_string(), + )) + } + + fn deserialize_u64(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::UnsupportedType( + "u64; NBT only supports signed values".to_string(), + )) } fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de>, { - if self.tag_to_deserialize.unwrap() == BYTE_ID { - let value = self.input.get_u8(); - if value != 0 { - return visitor.visit_bool(true); + if let Some(tag_id) = self.tag_to_deserialize_stack.last() { + if *tag_id == BYTE_ID { + let value = self.input.get_u8_be()?; + if value != 0 { + visitor.visit_bool(true) + } else { + visitor.visit_bool(false) + } + } else { + Err(Error::UnsupportedType(format!( + "Non-byte bool (found type {})", + tag_id + ))) } + } else { + Err(Error::SerdeError( + "Wanted to deserialize a bool, but there was no type hint on the stack!" + .to_string(), + )) } - visitor.visit_bool(false) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + let variant = get_nbt_string(&mut self.input)?; + visitor.visit_enum(variant.into_deserializer()) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + // None is not encoded, so no need for it + visitor.visit_some(self) } fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de>, { - if self.tag_to_deserialize.is_none() { - let next_byte = self.input.get_u8(); + if let Some(tag_id) = self.tag_to_deserialize_stack.last() { + if *tag_id != COMPOUND_ID { + return Err(Error::SerdeError( + "Trying to deserialize a map without a compound ID".to_string(), + )); + } + } else { + let next_byte = self.input.get_u8_be()?; if next_byte != COMPOUND_ID { return Err(Error::NoRootCompound(next_byte)); } if self.is_named { // Consume struct name - NbtTag::deserialize(self.input)?; + let _ = get_nbt_string(&mut self.input)?; } } @@ -150,7 +346,7 @@ impl<'de, T: Buf> de::Deserializer<'de> for &mut Deserializer<'de, T> { where V: Visitor<'de>, { - let str = get_nbt_string(&mut self.input).map_err(|_| Error::Cesu8DecodingError)?; + let str = get_nbt_string(&mut self.input)?; visitor.visit_string(str) } @@ -159,25 +355,25 @@ impl<'de, T: Buf> de::Deserializer<'de> for &mut Deserializer<'de, T> { } } -struct CompoundAccess<'a, 'de: 'a, T: Buf> { - de: &'a mut Deserializer<'de, T>, +struct CompoundAccess<'a, R: Read> { + de: &'a mut Deserializer, } -impl<'de, T: Buf> MapAccess<'de> for CompoundAccess<'_, 'de, T> { +impl<'de, R: Read> MapAccess<'de> for CompoundAccess<'_, R> { type Error = Error; fn next_key_seed(&mut self, seed: K) -> Result> where K: DeserializeSeed<'de>, { - let tag = self.de.input.get_u8(); - self.de.tag_to_deserialize = Some(tag); + let tag = self.de.input.get_u8_be()?; + self.de.tag_to_deserialize_stack.push(tag); if tag == END_ID { return Ok(None); } - seed.deserialize(&mut *self.de).map(Some) + seed.deserialize(MapKey { de: self.de }).map(Some) } fn next_value_seed(&mut self, seed: V) -> Result @@ -188,15 +384,40 @@ impl<'de, T: Buf> MapAccess<'de> for CompoundAccess<'_, 'de, T> { } } -struct ListAccess<'a, 'de: 'a, T: Buf> { - de: &'a mut Deserializer<'de, T>, - remaining_values: u32, +struct MapKey<'a, R: Read> { + de: &'a mut Deserializer, +} + +impl<'de, R: Read> de::Deserializer<'de> for MapKey<'_, R> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let key = get_nbt_string(&mut self.de.input)?; + visitor.visit_string(key) + } + + forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit unit_struct seq tuple tuple_struct map + struct identifier ignored_any bytes enum newtype_struct byte_buf option + } +} + +struct ListAccess<'a, R: Read> { + de: &'a mut Deserializer, + remaining_values: usize, list_type: u8, } -impl<'de, T: Buf> SeqAccess<'de> for ListAccess<'_, 'de, T> { +impl<'de, R: Read> SeqAccess<'de> for ListAccess<'_, R> { type Error = Error; + fn size_hint(&self) -> Option { + Some(self.remaining_values) + } + fn next_element_seed(&mut self, seed: E) -> Result> where E: DeserializeSeed<'de>, @@ -206,7 +427,11 @@ impl<'de, T: Buf> SeqAccess<'de> for ListAccess<'_, 'de, T> { } self.remaining_values -= 1; - self.de.tag_to_deserialize = Some(self.list_type); - seed.deserialize(&mut *self.de).map(Some) + self.de.tag_to_deserialize_stack.push(self.list_type); + self.de.in_list = true; + let result = seed.deserialize(&mut *self.de).map(Some); + self.de.in_list = false; + + result } } diff --git a/pumpkin-nbt/src/lib.rs b/pumpkin-nbt/src/lib.rs index 36d724e96..b299aa386 100644 --- a/pumpkin-nbt/src/lib.rs +++ b/pumpkin-nbt/src/lib.rs @@ -1,14 +1,14 @@ use std::{ fmt::Display, - io::{self, Cursor, Write}, + io::{self, Read, Write}, ops::Deref, }; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use cesu8::Cesu8DecodingError; +use bytes::Bytes; use compound::NbtCompound; +use deserializer::ReadAdaptor; use serde::{de, ser}; -use serde::{Deserialize, Deserializer}; +use serializer::WriteAdaptor; use tag::NbtTag; use thiserror::Error; @@ -19,19 +19,19 @@ pub mod tag; // This NBT crate is inspired from CrabNBT -pub const END_ID: u8 = 0; -pub const BYTE_ID: u8 = 1; -pub const SHORT_ID: u8 = 2; -pub const INT_ID: u8 = 3; -pub const LONG_ID: u8 = 4; -pub const FLOAT_ID: u8 = 5; -pub const DOUBLE_ID: u8 = 6; -pub const BYTE_ARRAY_ID: u8 = 7; -pub const STRING_ID: u8 = 8; -pub const LIST_ID: u8 = 9; -pub const COMPOUND_ID: u8 = 10; -pub const INT_ARRAY_ID: u8 = 11; -pub const LONG_ARRAY_ID: u8 = 12; +pub const END_ID: u8 = 0x00; +pub const BYTE_ID: u8 = 0x01; +pub const SHORT_ID: u8 = 0x02; +pub const INT_ID: u8 = 0x03; +pub const LONG_ID: u8 = 0x04; +pub const FLOAT_ID: u8 = 0x05; +pub const DOUBLE_ID: u8 = 0x06; +pub const BYTE_ARRAY_ID: u8 = 0x07; +pub const STRING_ID: u8 = 0x08; +pub const LIST_ID: u8 = 0x09; +pub const COMPOUND_ID: u8 = 0x0A; +pub const INT_ARRAY_ID: u8 = 0x0B; +pub const LONG_ARRAY_ID: u8 = 0x0C; #[derive(Error, Debug)] pub enum Error { @@ -45,6 +45,12 @@ pub enum Error { SerdeError(String), #[error("NBT doesn't support this type {0}")] UnsupportedType(String), + #[error("NBT reading was cut short {0}")] + Incomplete(io::Error), + #[error("Negative list length {0}")] + NegativeLength(i32), + #[error("Length too large {0}")] + LargeLength(usize), } impl ser::Error for Error { @@ -73,26 +79,28 @@ impl Nbt { } } - pub fn read(bytes: &mut impl Buf) -> Result { - let tag_type_id = bytes.get_u8(); + pub fn read(reader: &mut ReadAdaptor) -> Result + where + R: Read, + { + let tag_type_id = reader.get_u8_be()?; if tag_type_id != COMPOUND_ID { return Err(Error::NoRootCompound(tag_type_id)); } Ok(Nbt { - name: get_nbt_string(bytes).map_err(|_| Error::Cesu8DecodingError)?, - root_tag: NbtCompound::deserialize_content(bytes)?, + name: get_nbt_string(reader)?, + root_tag: NbtCompound::deserialize_content(reader)?, }) } - pub fn read_from_cursor(cursor: &mut Cursor<&[u8]>) -> Result { - Self::read(cursor) - } - /// Reads NBT tag, that doesn't contain the name of root compound. - pub fn read_unnamed(bytes: &mut impl Buf) -> Result { - let tag_type_id = bytes.get_u8(); + pub fn read_unnamed(reader: &mut ReadAdaptor) -> Result + where + R: Read, + { + let tag_type_id = reader.get_u8_be()?; if tag_type_id != COMPOUND_ID { return Err(Error::NoRootCompound(tag_type_id)); @@ -100,21 +108,20 @@ impl Nbt { Ok(Nbt { name: String::new(), - root_tag: NbtCompound::deserialize_content(bytes) - .map_err(|_| Error::Cesu8DecodingError)?, + root_tag: NbtCompound::deserialize_content(reader)?, }) } - pub fn read_unnamed_from_cursor(cursor: &mut Cursor<&[u8]>) -> Result { - Self::read_unnamed(cursor) - } - pub fn write(&self) -> Bytes { - let mut bytes = BytesMut::new(); - bytes.put_u8(COMPOUND_ID); - bytes.put(NbtTag::String(self.name.to_string()).serialize_data()); - bytes.put(self.root_tag.serialize_content()); - bytes.freeze() + let mut bytes = Vec::new(); + let mut writer = WriteAdaptor::new(&mut bytes); + writer.write_u8_be(COMPOUND_ID).unwrap(); + NbtTag::String(self.name.to_string()) + .serialize_data(&mut writer) + .unwrap(); + self.root_tag.serialize_content(&mut writer).unwrap(); + + bytes.into() } pub fn write_to_writer(&self, mut writer: W) -> Result<(), io::Error> { @@ -124,10 +131,13 @@ impl Nbt { /// Writes NBT tag, without name of root compound. pub fn write_unnamed(&self) -> Bytes { - let mut bytes = BytesMut::new(); - bytes.put_u8(COMPOUND_ID); - bytes.put(self.root_tag.serialize_content()); - bytes.freeze() + let mut bytes = Vec::new(); + let mut writer = WriteAdaptor::new(&mut bytes); + + writer.write_u8_be(COMPOUND_ID).unwrap(); + self.root_tag.serialize_content(&mut writer).unwrap(); + + bytes.into() } pub fn write_unnamed_to_writer(&self, mut writer: W) -> Result<(), io::Error> { @@ -166,49 +176,46 @@ impl AsMut for Nbt { } } -pub fn get_nbt_string(bytes: &mut impl Buf) -> Result { - let len = bytes.get_u16() as usize; - let string_bytes = bytes.copy_to_bytes(len); - let string = cesu8::from_java_cesu8(&string_bytes)?; +pub fn get_nbt_string(bytes: &mut ReadAdaptor) -> Result { + let len = bytes.get_u16_be()? as usize; + let string_bytes = bytes.read_boxed_slice(len)?; + let string = cesu8::from_java_cesu8(&string_bytes).map_err(|_| Error::Cesu8DecodingError)?; Ok(string.to_string()) } +pub(crate) const NBT_ARRAY_TAG: &str = "__nbt_array"; +pub(crate) const NBT_INT_ARRAY_TAG: &str = "__nbt_int_array"; +pub(crate) const NBT_LONG_ARRAY_TAG: &str = "__nbt_long_array"; +pub(crate) const NBT_BYTE_ARRAY_TAG: &str = "__nbt_byte_array"; + macro_rules! impl_array { ($name:ident, $variant:expr) => { - pub struct $name; - - impl $name { - pub fn serialize(input: T, serializer: S) -> Result - where - T: serde::Serialize, - S: serde::Serializer, - { - serializer.serialize_newtype_variant("nbt_array", 0, $variant, &input) - } - - pub fn deserialize<'de, T, D>(deserializer: D) -> Result - where - T: Deserialize<'de>, - D: Deserializer<'de>, - { - T::deserialize(deserializer) - } + pub fn $name(input: T, serializer: S) -> Result + where + T: serde::Serialize, + S: serde::Serializer, + { + serializer.serialize_newtype_variant(NBT_ARRAY_TAG, 0, $variant, &input) } }; } -impl_array!(IntArray, "int"); -impl_array!(LongArray, "long"); -impl_array!(BytesArray, "byte"); +impl_array!(nbt_int_array, NBT_INT_ARRAY_TAG); +impl_array!(nbt_long_array, NBT_LONG_ARRAY_TAG); +impl_array!(nbt_byte_array, NBT_BYTE_ARRAY_TAG); #[cfg(test)] mod test { - use serde::{Deserialize, Serialize}; - use crate::BytesArray; - use crate::IntArray; - use crate::LongArray; + use crate::deserializer::from_bytes; + use crate::nbt_byte_array; + use crate::nbt_int_array; + use crate::nbt_long_array; + use crate::serializer::to_bytes; + use crate::serializer::to_bytes_named; + use crate::Error; use crate::{deserializer::from_bytes_unnamed, serializer::to_bytes_unnamed}; + use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Debug)] struct Test { @@ -230,19 +237,21 @@ mod test { float: 1.00, string: "Hello test".to_string(), }; - let mut bytes = to_bytes_unnamed(&test).unwrap(); - let recreated_struct: Test = from_bytes_unnamed(&mut bytes).unwrap(); + + let mut bytes = Vec::new(); + to_bytes_unnamed(&test, &mut bytes).unwrap(); + let recreated_struct: Test = from_bytes_unnamed(&bytes[..]).unwrap(); assert_eq!(test, recreated_struct); } #[derive(Serialize, Deserialize, PartialEq, Debug)] struct TestArray { - #[serde(with = "BytesArray")] + #[serde(serialize_with = "nbt_byte_array")] byte_array: Vec, - #[serde(with = "IntArray")] + #[serde(serialize_with = "nbt_int_array")] int_array: Vec, - #[serde(with = "LongArray")] + #[serde(serialize_with = "nbt_long_array")] long_array: Vec, } @@ -253,9 +262,260 @@ mod test { int_array: vec![13, 1321, 2], long_array: vec![1, 0, 200301, 1], }; - let mut bytes = to_bytes_unnamed(&test).unwrap(); - let recreated_struct: TestArray = from_bytes_unnamed(&mut bytes).unwrap(); + + let mut bytes = Vec::new(); + to_bytes_unnamed(&test, &mut bytes).unwrap(); + let recreated_struct: TestArray = from_bytes_unnamed(&bytes[..]).unwrap(); + + assert_eq!(test, recreated_struct); + } + + #[test] + fn test_simple_ser_de_named() { + let name = String::from("Test"); + let test = Test { + byte: 123, + short: 1342, + int: 4313, + long: 34, + float: 1.00, + string: "Hello test".to_string(), + }; + + let mut bytes = Vec::new(); + to_bytes_named(&test, name, &mut bytes).unwrap(); + let recreated_struct: Test = from_bytes(&bytes[..]).unwrap(); assert_eq!(test, recreated_struct); } + + #[test] + fn test_simple_ser_de_array_named() { + let name = String::from("Test"); + let test = TestArray { + byte_array: vec![0, 3, 2], + int_array: vec![13, 1321, 2], + long_array: vec![1, 0, 200301, 1], + }; + + let mut bytes = Vec::new(); + to_bytes_named(&test, name, &mut bytes).unwrap(); + let recreated_struct: TestArray = from_bytes(&bytes[..]).unwrap(); + + assert_eq!(test, recreated_struct); + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Egg { + food: String, + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Breakfast { + food: Egg, + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct TestList { + option: Option, + nested_compound: Breakfast, + compounds: Vec, + list_string: Vec, + empty: Vec, + } + + #[test] + fn test_list() { + let test1 = Test { + byte: 123, + short: 1342, + int: 4313, + long: 34, + float: 1.00, + string: "Hello test".to_string(), + }; + + let test2 = Test { + byte: 13, + short: 342, + int: -4313, + long: -132334, + float: -69.420, + string: "Hello compounds".to_string(), + }; + + let list_compound = TestList { + option: Some(Egg { + food: "Skibid".to_string(), + }), + nested_compound: Breakfast { + food: Egg { + food: "Over easy".to_string(), + }, + }, + compounds: vec![test1, test2], + list_string: vec!["".to_string(), "abcbcbcbbc".to_string()], + empty: vec![], + }; + + let mut bytes = Vec::new(); + to_bytes_unnamed(&list_compound, &mut bytes).unwrap(); + let recreated_struct: TestList = from_bytes_unnamed(&bytes[..]).unwrap(); + assert_eq!(list_compound, recreated_struct); + } + + #[test] + fn test_list_named() { + let test1 = Test { + byte: 123, + short: 1342, + int: 4313, + long: 34, + float: 1.00, + string: "Hello test".to_string(), + }; + + let test2 = Test { + byte: 13, + short: 342, + int: -4313, + long: -132334, + float: -69.420, + string: "Hello compounds".to_string(), + }; + + let list_compound = TestList { + option: None, + nested_compound: Breakfast { + food: Egg { + food: "Over easy".to_string(), + }, + }, + compounds: vec![test1, test2], + list_string: vec!["".to_string(), "abcbcbcbbc".to_string()], + empty: vec![], + }; + + let mut bytes = Vec::new(); + to_bytes_named(&list_compound, "a".to_string(), &mut bytes).unwrap(); + let recreated_struct: TestList = from_bytes(&bytes[..]).unwrap(); + assert_eq!(list_compound, recreated_struct); + } + + #[test] + fn test_nbt_arrays() { + #[derive(Serialize)] + struct Tagged { + #[serde(serialize_with = "nbt_long_array")] + l: [i64; 1], + #[serde(serialize_with = "nbt_int_array")] + i: [i32; 1], + #[serde(serialize_with = "nbt_byte_array")] + b: [u8; 1], + } + + let value = Tagged { + l: [0], + i: [0], + b: [0], + }; + let expected_bytes = [ + 0x0A, // Component Tag + 0x00, 0x00, // Empty root name + 0x0C, // Long Array Type + 0x00, 0x01, // Key length + 0x6C, // Key (l) + 0x00, 0x00, 0x00, 0x01, // Array Length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Value(s) + 0x0B, // Int Array Tag + 0x00, 0x01, // Key length + 0x69, // Key (i) + 0x00, 0x00, 0x00, 0x01, // Array Length + 0x00, 0x00, 0x00, 0x00, // Value(s) + 0x07, // Byte Array Tag + 0x00, 0x01, // Key length + 0x62, // Key (b) + 0x00, 0x00, 0x00, 0x01, // Array Length + 0x00, // Value(s) + 0x00, // End Tag + ]; + + let mut bytes = Vec::new(); + to_bytes(&value, &mut bytes).unwrap(); + assert_eq!(bytes, expected_bytes); + + #[derive(Serialize)] + struct NotTagged { + l: [i64; 1], + i: [i32; 1], + b: [u8; 1], + } + + let value = NotTagged { + l: [0], + i: [0], + b: [0], + }; + let expected_bytes = [ + 0x0A, // Component Tag + 0x00, 0x00, // Empty root name + 0x09, // List Tag + 0x00, 0x01, // Key length + 0x6C, // Key (l) + 0x04, // Array Type + 0x00, 0x00, 0x00, 0x01, // Array Length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Value(s) + 0x09, // List Tag + 0x00, 0x01, // Key length + 0x69, // Key (i) + 0x03, // Array Type + 0x00, 0x00, 0x00, 0x01, // Array Length + 0x00, 0x00, 0x00, 0x00, // Value(s) + 0x09, // List Tag + 0x00, 0x01, // Key length + 0x62, // Key (b) + 0x01, // Array Type + 0x00, 0x00, 0x00, 0x01, // Array Length + 0x00, // Value(s) + 0x00, // End Tag + ]; + + let mut bytes = Vec::new(); + to_bytes(&value, &mut bytes).unwrap(); + assert_eq!(bytes, expected_bytes); + } + + #[test] + fn test_tuple_fail() { + #[derive(Serialize)] + struct BadData { + x: (i32, i64), + } + + let value = BadData { x: (0, 0) }; + let mut bytes = Vec::new(); + let err = to_bytes(&value, &mut bytes); + + match err { + Err(Error::SerdeError(_)) => (), + _ => panic!("Expected to fail serialization!"), + }; + } + + #[test] + fn test_tuple_ok() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct GoodData { + x: (i32, i32), + } + + let value = GoodData { x: (1, 2) }; + let mut bytes = Vec::new(); + to_bytes(&value, &mut bytes).unwrap(); + + let reconstructed = from_bytes(&bytes[..]).unwrap(); + assert_eq!(value, reconstructed); + } + + // TODO: More robust tests } diff --git a/pumpkin-nbt/src/serializer.rs b/pumpkin-nbt/src/serializer.rs index c150d5f32..fdc6f9c20 100644 --- a/pumpkin-nbt/src/serializer.rs +++ b/pumpkin-nbt/src/serializer.rs @@ -1,4 +1,3 @@ -use bytes::{BufMut, BytesMut}; use serde::ser::Impossible; use serde::{ser, Serialize}; use std::io::Write; @@ -6,20 +5,95 @@ use std::io::Write; use crate::tag::NbtTag; use crate::{ Error, BYTE_ARRAY_ID, BYTE_ID, COMPOUND_ID, DOUBLE_ID, END_ID, FLOAT_ID, INT_ARRAY_ID, INT_ID, - LIST_ID, LONG_ARRAY_ID, LONG_ID, SHORT_ID, STRING_ID, + LIST_ID, LONG_ARRAY_ID, LONG_ID, NBT_ARRAY_TAG, NBT_BYTE_ARRAY_TAG, NBT_INT_ARRAY_TAG, + NBT_LONG_ARRAY_TAG, SHORT_ID, STRING_ID, }; pub type Result = std::result::Result; -pub trait SerializeChild { - fn serialize_child(&self, serializer: S) -> std::result::Result - where - S: ser::Serializer; +#[derive(Debug)] +pub struct WriteAdaptor { + writer: W, +} + +impl WriteAdaptor { + pub fn new(w: W) -> Self { + Self { writer: w } + } +} + +impl WriteAdaptor { + //TODO: Macroize this + pub fn write_u8_be(&mut self, value: u8) -> Result<()> { + let buf = value.to_be_bytes(); + self.writer.write_all(&buf).map_err(Error::Incomplete)?; + Ok(()) + } + + pub fn write_i8_be(&mut self, value: i8) -> Result<()> { + let buf = value.to_be_bytes(); + self.writer.write_all(&buf).map_err(Error::Incomplete)?; + Ok(()) + } + + pub fn write_u16_be(&mut self, value: u16) -> Result<()> { + let buf = value.to_be_bytes(); + self.writer.write_all(&buf).map_err(Error::Incomplete)?; + Ok(()) + } + + pub fn write_i16_be(&mut self, value: i16) -> Result<()> { + let buf = value.to_be_bytes(); + self.writer.write_all(&buf).map_err(Error::Incomplete)?; + Ok(()) + } + + pub fn write_i32_be(&mut self, value: i32) -> Result<()> { + let buf = value.to_be_bytes(); + self.writer.write_all(&buf).map_err(Error::Incomplete)?; + Ok(()) + } + + pub fn write_i64_be(&mut self, value: i64) -> Result<()> { + let buf = value.to_be_bytes(); + self.writer.write_all(&buf).map_err(Error::Incomplete)?; + Ok(()) + } + + pub fn write_f32_be(&mut self, value: f32) -> Result<()> { + let buf = value.to_be_bytes(); + self.writer.write_all(&buf).map_err(Error::Incomplete)?; + Ok(()) + } + + pub fn write_f64_be(&mut self, value: f64) -> Result<()> { + let buf = value.to_be_bytes(); + self.writer.write_all(&buf).map_err(Error::Incomplete)?; + Ok(()) + } + + pub fn write_slice(&mut self, value: &[u8]) -> Result<()> { + self.writer.write_all(value).map_err(Error::Incomplete)?; + Ok(()) + } } -pub struct Serializer { - output: BytesMut, +pub struct Serializer { + output: WriteAdaptor, state: State, + handled_root: bool, + expected_list_tag: u8, +} + +impl Serializer { + pub fn new(output: W, name: Option) -> Self { + Self { + output: WriteAdaptor::new(output), + state: State::Root(name), + handled_root: false, + expected_list_tag: 0, + } + } } // NBT has a different order of things, then most other formats @@ -31,22 +105,28 @@ enum State { Named(String), // Used by maps, to check if key is String MapKey, - FirstListElement { len: i32 }, + FirstListElement { + len: i32, + }, ListElement, - Array { name: String, array_type: String }, + CheckedListElement, + Array { + name: String, + array_type: &'static str, + }, } -impl Serializer { +impl Serializer { fn parse_state(&mut self, tag: u8) -> Result<()> { match &mut self.state { State::Named(name) | State::Array { name, .. } => { - self.output.put_u8(tag); - self.output - .put(NbtTag::String(name.clone()).serialize_data()); + self.output.write_u8_be(tag)?; + NbtTag::String(name.clone()).serialize_data(&mut self.output)?; } State::FirstListElement { len } => { - self.output.put_u8(tag); - self.output.put_i32(*len); + self.output.write_u8_be(tag)?; + self.output.write_i32_be(*len)?; + self.expected_list_tag = tag; } State::MapKey => { if tag != STRING_ID { @@ -55,76 +135,74 @@ impl Serializer { ))); } } - State::ListElement => {} - _ => return Err(Error::SerdeError("Invalid Serializer state!".to_string())), + State::ListElement => { + // Rust rules mandate this is all the same type + } + State::CheckedListElement => { + if tag != self.expected_list_tag { + return Err(Error::SerdeError(format!( + "List values must all be of the same type! Expected {} but found {}!", + self.expected_list_tag, tag + ))); + } + } + State::Root(root_name) => { + if self.handled_root { + return Err(Error::SerdeError( + "Invalid state: already handled root component!".to_string(), + )); + } else { + if tag != COMPOUND_ID { + return Err(Error::SerdeError(format!( + "Invalid state: root is not a compound! ({})", + tag + ))); + } + self.handled_root = true; + self.output.write_u8_be(tag)?; + if let Some(root_name) = root_name { + NbtTag::String(root_name.clone()).serialize_data(&mut self.output)?; + } + } + } }; Ok(()) } } -/// Serializes struct using Serde Serializer to unnamed (network) NBT (Exclusive to TextComponent) -pub fn to_bytes_text_component(value: &T) -> Result -where - T: SerializeChild, -{ - let mut serializer = Serializer { - output: BytesMut::new(), - state: State::Root(None), - }; - value.serialize_child(&mut serializer)?; - Ok(serializer.output) -} - /// Serializes struct using Serde Serializer to unnamed (network) NBT -pub fn to_bytes_unnamed(value: &T) -> Result +pub fn to_bytes_unnamed(value: &T, w: impl Write) -> Result<()> where T: Serialize, { - let mut serializer = Serializer { - output: BytesMut::new(), - state: State::Root(None), - }; + let mut serializer = Serializer::new(w, None); value.serialize(&mut serializer)?; - Ok(serializer.output) -} -/// Serializes struct using Serde Serializer to unnamed NBT -pub fn to_writer_unnamed(value: &T, mut writer: W) -> Result<()> -where - T: Serialize, - W: Write, -{ - writer.write_all(&to_bytes_unnamed(value)?).unwrap(); Ok(()) } /// Serializes struct using Serde Serializer to normal NBT -pub fn to_bytes(value: &T, name: String) -> Result +pub fn to_bytes_named(value: &T, name: String, w: impl Write) -> Result<()> where T: Serialize, { - let mut serializer = Serializer { - output: BytesMut::new(), - state: State::Root(Some(name)), - }; + let mut serializer = Serializer::new(w, Some(name)); value.serialize(&mut serializer)?; - Ok(serializer.output) + Ok(()) } -pub fn to_writer(value: &T, name: String, mut writer: W) -> Result<()> +pub fn to_bytes(value: &T, w: impl Write) -> Result<()> where T: Serialize, - W: Write, { - writer.write_all(&to_bytes(value, name)?).unwrap(); - Ok(()) + to_bytes_named(value, String::new(), w) } -impl ser::Serializer for &mut Serializer { +impl ser::Serializer for &mut Serializer { type Ok = (); type Error = Error; type SerializeSeq = Self; - type SerializeTuple = Impossible<(), Error>; + type SerializeTuple = Self; type SerializeTupleStruct = Impossible<(), Error>; type SerializeTupleVariant = Impossible<(), Error>; type SerializeMap = Self; @@ -138,61 +216,68 @@ impl ser::Serializer for &mut Serializer { fn serialize_i8(self, v: i8) -> Result<()> { self.parse_state(BYTE_ID)?; - self.output.put_i8(v); + self.output.write_i8_be(v)?; Ok(()) } fn serialize_i16(self, v: i16) -> Result<()> { self.parse_state(SHORT_ID)?; - self.output.put_i16(v); + self.output.write_i16_be(v)?; Ok(()) } fn serialize_i32(self, v: i32) -> Result<()> { self.parse_state(INT_ID)?; - self.output.put_i32(v); + self.output.write_i32_be(v)?; Ok(()) } fn serialize_i64(self, v: i64) -> Result<()> { self.parse_state(LONG_ID)?; - self.output.put_i64(v); + self.output.write_i64_be(v)?; Ok(()) } fn serialize_u8(self, v: u8) -> Result<()> { - self.parse_state(BYTE_ID)?; - self.output.put_u8(v); - Ok(()) + match self.state { + State::Named(_) => Err(Error::UnsupportedType( + "u8; NBT only supports signed values".to_string(), + )), + _ => { + self.parse_state(BYTE_ID)?; + self.output.write_u8_be(v)?; + Ok(()) + } + } } - fn serialize_u16(self, v: u16) -> Result<()> { - self.parse_state(SHORT_ID)?; - self.output.put_u16(v); - Ok(()) + fn serialize_u16(self, _v: u16) -> Result<()> { + Err(Error::UnsupportedType( + "u16; NBT only supports signed values".to_string(), + )) } - fn serialize_u32(self, v: u32) -> Result<()> { - self.parse_state(INT_ID)?; - self.output.put_u32(v); - Ok(()) + fn serialize_u32(self, _v: u32) -> Result<()> { + Err(Error::UnsupportedType( + "u32; NBT only supports signed values".to_string(), + )) } - fn serialize_u64(self, v: u64) -> Result<()> { - self.parse_state(LONG_ID)?; - self.output.put_u64(v); - Ok(()) + fn serialize_u64(self, _v: u64) -> Result<()> { + Err(Error::UnsupportedType( + "u64; NBT only supports signed values".to_string(), + )) } fn serialize_f32(self, v: f32) -> Result<()> { self.parse_state(FLOAT_ID)?; - self.output.put_f32(v); + self.output.write_f32_be(v)?; Ok(()) } fn serialize_f64(self, v: f64) -> Result<()> { self.parse_state(DOUBLE_ID)?; - self.output.put_f64(v); + self.output.write_f64_be(v)?; Ok(()) } @@ -202,21 +287,27 @@ impl ser::Serializer for &mut Serializer { fn serialize_str(self, v: &str) -> Result<()> { self.parse_state(STRING_ID)?; + if self.state == State::MapKey { self.state = State::Named(v.to_string()); - return Ok(()); + } else { + NbtTag::String(v.to_string()).serialize_data(&mut self.output)?; } - self.output - .put(NbtTag::String(v.to_string()).serialize_data()); Ok(()) } fn serialize_bytes(self, v: &[u8]) -> Result<()> { self.parse_state(LIST_ID)?; - self.output.put_u8(BYTE_ID); - self.output.put_i32(v.len() as i32); - self.output.put_slice(v); + self.output.write_u8_be(BYTE_ID)?; + + let len = v.len(); + if len > i32::MAX as usize { + return Err(Error::LargeLength(len)); + } + + self.output.write_i32_be(len as i32)?; + self.output.write_slice(v)?; Ok(()) } @@ -267,63 +358,69 @@ impl ser::Serializer for &mut Serializer { where T: ?Sized + Serialize, { - if name != "nbt_array" { - return Err(Error::SerdeError( - "new_type variant supports only nbt_array".to_string(), - )); + if name == NBT_ARRAY_TAG { + let name = match self.state { + State::Named(ref name) => name.clone(), + _ => return Err(Error::SerdeError("Invalid Serializer state!".to_string())), + }; + + self.state = State::Array { + name, + array_type: variant, + }; + } else { + return Err(Error::UnsupportedType("newtype variant".to_string())); } - - let name = match self.state { - State::Named(ref name) => name.clone(), - _ => return Err(Error::SerdeError("Invalid Serializer state!".to_string())), - }; - - self.state = State::Array { - name, - array_type: variant.to_string(), - }; - - value.serialize(self)?; - - Ok(()) + value.serialize(self) } fn serialize_seq(self, len: Option) -> Result { - if len.is_none() { + let Some(len) = len else { return Err(Error::SerdeError( "Length of the sequence must be known first!".to_string(), )); + }; + if len > i32::MAX as usize { + return Err(Error::LargeLength(len)); } match &mut self.state { State::Array { array_type, .. } => { - let id = match array_type.as_str() { - "byte" => BYTE_ARRAY_ID, - "int" => INT_ARRAY_ID, - "long" => LONG_ARRAY_ID, + let (id, expected_tag) = match *array_type { + NBT_BYTE_ARRAY_TAG => (BYTE_ARRAY_ID, BYTE_ID), + NBT_INT_ARRAY_TAG => (INT_ARRAY_ID, INT_ID), + NBT_LONG_ARRAY_TAG => (LONG_ARRAY_ID, LONG_ID), _ => { return Err(Error::SerdeError( "Array supports only byte, int, long".to_string(), )) } }; + self.parse_state(id)?; - self.output.put_i32(len.unwrap() as i32); - self.state = State::ListElement; + self.output.write_i32_be(len as i32)?; + + // We can mark anything as an nbt array list, so mark as needed to be checked + self.expected_list_tag = expected_tag; + self.state = State::CheckedListElement; } _ => { self.parse_state(LIST_ID)?; - self.state = State::FirstListElement { - len: len.unwrap() as i32, - }; + self.state = State::FirstListElement { len: len as i32 }; + if len == 0 { + // If we have no elements, FirstListElement state will never be invoked; so + // write the (unknown) list type and length here. + self.output.write_u8_be(END_ID)?; + self.output.write_i32_be(0)?; + } } } Ok(self) } - fn serialize_tuple(self, _len: usize) -> Result { - Err(Error::UnsupportedType("tuple".to_string())) + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) } fn serialize_tuple_struct( @@ -345,35 +442,12 @@ impl ser::Serializer for &mut Serializer { } fn serialize_map(self, _len: Option) -> Result { - if let State::FirstListElement { .. } = self.state { - self.parse_state(COMPOUND_ID)?; - } else if let State::ListElement = self.state { - return Ok(self); - } else { - self.output.put_u8(COMPOUND_ID); - } + self.parse_state(COMPOUND_ID)?; Ok(self) } fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - self.output.put_u8(COMPOUND_ID); - - match &mut self.state { - State::Root(root_name) => { - if let Some(root_name) = root_name { - self.output - .put(NbtTag::String(root_name.clone()).serialize_data()); - } - } - State::Named(string) => { - self.output - .put(NbtTag::String(string.clone()).serialize_data()); - } - _ => { - unimplemented!() - } - } - + self.parse_state(COMPOUND_ID)?; Ok(self) } @@ -392,7 +466,25 @@ impl ser::Serializer for &mut Serializer { } } -impl ser::SerializeSeq for &mut Serializer { +impl ser::SerializeTuple for &mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> std::result::Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut **self)?; + self.state = State::CheckedListElement; + Ok(()) + } + + fn end(self) -> Result<()> { + Ok(()) + } +} + +impl ser::SerializeSeq for &mut Serializer { type Ok = (); type Error = Error; @@ -410,7 +502,7 @@ impl ser::SerializeSeq for &mut Serializer { } } -impl ser::SerializeStruct for &mut Serializer { +impl ser::SerializeStruct for &mut Serializer { type Ok = (); type Error = Error; @@ -423,12 +515,12 @@ impl ser::SerializeStruct for &mut Serializer { } fn end(self) -> Result<()> { - self.output.put_u8(END_ID); + self.output.write_u8_be(END_ID)?; Ok(()) } } -impl ser::SerializeMap for &mut Serializer { +impl ser::SerializeMap for &mut Serializer { type Ok = (); type Error = Error; @@ -448,7 +540,7 @@ impl ser::SerializeMap for &mut Serializer { } fn end(self) -> Result<()> { - self.output.put_u8(END_ID); + self.output.write_u8_be(END_ID)?; Ok(()) } } diff --git a/pumpkin-nbt/src/tag.rs b/pumpkin-nbt/src/tag.rs index 2d4b78ffb..1e7f48a3f 100644 --- a/pumpkin-nbt/src/tag.rs +++ b/pumpkin-nbt/src/tag.rs @@ -1,7 +1,7 @@ -use std::io::Cursor; - -use bytes::{Bytes, BytesMut}; use compound::NbtCompound; +use deserializer::ReadAdaptor; +use io::Read; +use serializer::WriteAdaptor; use crate::*; @@ -15,12 +15,12 @@ pub enum NbtTag { Long(i64) = LONG_ID, Float(f32) = FLOAT_ID, Double(f64) = DOUBLE_ID, - ByteArray(Bytes) = BYTE_ARRAY_ID, + ByteArray(Box<[u8]>) = BYTE_ARRAY_ID, String(String) = STRING_ID, - List(Vec) = LIST_ID, + List(Box<[NbtTag]>) = LIST_ID, Compound(NbtCompound) = COMPOUND_ID, - IntArray(Vec) = INT_ARRAY_ID, - LongArray(Vec) = LONG_ARRAY_ID, + IntArray(Box<[i32]>) = INT_ARRAY_ID, + LongArray(Box<[i64]>) = LONG_ARRAY_ID, } impl NbtTag { @@ -30,141 +30,241 @@ impl NbtTag { unsafe { *(self as *const Self as *const u8) } } - pub fn serialize(&self) -> Bytes { - let mut bytes = BytesMut::new(); - bytes.put_u8(self.get_type_id()); - bytes.put(self.serialize_data()); - bytes.freeze() + pub fn serialize(&self, w: &mut WriteAdaptor) -> serializer::Result<()> + where + W: Write, + { + w.write_u8_be(self.get_type_id())?; + self.serialize_data(w)?; + Ok(()) } - pub fn serialize_data(&self) -> Bytes { - let mut bytes = BytesMut::new(); + pub fn serialize_data(&self, w: &mut WriteAdaptor) -> serializer::Result<()> + where + W: Write, + { match self { NbtTag::End => {} - NbtTag::Byte(byte) => bytes.put_i8(*byte), - NbtTag::Short(short) => bytes.put_i16(*short), - NbtTag::Int(int) => bytes.put_i32(*int), - NbtTag::Long(long) => bytes.put_i64(*long), - NbtTag::Float(float) => bytes.put_f32(*float), - NbtTag::Double(double) => bytes.put_f64(*double), + NbtTag::Byte(byte) => w.write_i8_be(*byte)?, + NbtTag::Short(short) => w.write_i16_be(*short)?, + NbtTag::Int(int) => w.write_i32_be(*int)?, + NbtTag::Long(long) => w.write_i64_be(*long)?, + NbtTag::Float(float) => w.write_f32_be(*float)?, + NbtTag::Double(double) => w.write_f64_be(*double)?, NbtTag::ByteArray(byte_array) => { - bytes.put_i32(byte_array.len() as i32); - bytes.put_slice(byte_array); + let len = byte_array.len(); + if len > i32::MAX as usize { + return Err(Error::LargeLength(len)); + } + + w.write_i32_be(len as i32)?; + w.write_slice(byte_array)?; } NbtTag::String(string) => { let java_string = cesu8::to_java_cesu8(string); - bytes.put_u16(java_string.len() as u16); - bytes.put_slice(&java_string); + let len = java_string.len(); + if len > u16::MAX as usize { + return Err(Error::LargeLength(len)); + } + + w.write_u16_be(len as u16)?; + w.write_slice(&java_string)?; } NbtTag::List(list) => { - bytes.put_u8(list.first().unwrap_or(&NbtTag::End).get_type_id()); - bytes.put_i32(list.len() as i32); + let len = list.len(); + if len > i32::MAX as usize { + return Err(Error::LargeLength(len)); + } + + w.write_u8_be(list.first().unwrap_or(&NbtTag::End).get_type_id())?; + w.write_i32_be(len as i32)?; for nbt_tag in list { - bytes.put(nbt_tag.serialize_data()) + nbt_tag.serialize_data(w)?; } } NbtTag::Compound(compound) => { - bytes.put(compound.serialize_content()); + compound.serialize_content(w)?; } NbtTag::IntArray(int_array) => { - bytes.put_i32(int_array.len() as i32); + let len = int_array.len(); + if len > i32::MAX as usize { + return Err(Error::LargeLength(len)); + } + + w.write_i32_be(len as i32)?; for int in int_array { - bytes.put_i32(*int) + w.write_i32_be(*int)?; } } NbtTag::LongArray(long_array) => { - bytes.put_i32(long_array.len() as i32); + let len = long_array.len(); + if len > i32::MAX as usize { + return Err(Error::LargeLength(len)); + } + + w.write_i32_be(len as i32)?; + for long in long_array { - bytes.put_i64(*long) + w.write_i64_be(*long)?; } } - } - bytes.freeze() + }; + Ok(()) } - pub fn deserialize(bytes: &mut impl Buf) -> Result { - let tag_id = bytes.get_u8(); - Self::deserialize_data(bytes, tag_id) + pub fn deserialize(reader: &mut ReadAdaptor) -> Result + where + R: Read, + { + let tag_id = reader.get_u8_be()?; + Self::deserialize_data(reader, tag_id) } - pub fn deserialize_from_cursor(cursor: &mut Cursor<&[u8]>) -> Result { - Self::deserialize(cursor) + pub fn skip_data(reader: &mut ReadAdaptor, tag_id: u8) -> Result<(), Error> + where + R: Read, + { + match tag_id { + END_ID => Ok(()), + BYTE_ID => reader.skip_bytes(1), + SHORT_ID => reader.skip_bytes(2), + INT_ID => reader.skip_bytes(4), + LONG_ID => reader.skip_bytes(8), + FLOAT_ID => reader.skip_bytes(4), + DOUBLE_ID => reader.skip_bytes(8), + BYTE_ARRAY_ID => { + let len = reader.get_i32_be()?; + if len < 0 { + return Err(Error::NegativeLength(len)); + } + reader.skip_bytes(len as u64) + } + STRING_ID => { + let len = reader.get_u16_be()?; + reader.skip_bytes(len as u64) + } + LIST_ID => { + let tag_type_id = reader.get_u8_be()?; + let len = reader.get_i32_be()?; + if len < 0 { + return Err(Error::NegativeLength(len)); + } + + for _ in 0..len { + Self::skip_data(reader, tag_type_id)?; + } + + Ok(()) + } + COMPOUND_ID => NbtCompound::skip_content(reader), + INT_ARRAY_ID => { + let len = reader.get_i32_be()?; + if len < 0 { + return Err(Error::NegativeLength(len)); + } + + reader.skip_bytes(len as u64 * 4) + } + LONG_ARRAY_ID => { + let len = reader.get_i32_be()?; + if len < 0 { + return Err(Error::NegativeLength(len)); + } + + reader.skip_bytes(len as u64 * 8) + } + _ => Err(Error::UnknownTagId(tag_id)), + } } - pub fn deserialize_data(bytes: &mut impl Buf, tag_id: u8) -> Result { + pub fn deserialize_data(reader: &mut ReadAdaptor, tag_id: u8) -> Result + where + R: Read, + { match tag_id { END_ID => Ok(NbtTag::End), BYTE_ID => { - let byte = bytes.get_i8(); + let byte = reader.get_i8_be()?; Ok(NbtTag::Byte(byte)) } SHORT_ID => { - let short = bytes.get_i16(); + let short = reader.get_i16_be()?; Ok(NbtTag::Short(short)) } INT_ID => { - let int = bytes.get_i32(); + let int = reader.get_i32_be()?; Ok(NbtTag::Int(int)) } LONG_ID => { - let long = bytes.get_i64(); + let long = reader.get_i64_be()?; Ok(NbtTag::Long(long)) } FLOAT_ID => { - let float = bytes.get_f32(); + let float = reader.get_f32_be()?; Ok(NbtTag::Float(float)) } DOUBLE_ID => { - let double = bytes.get_f64(); + let double = reader.get_f64_be()?; Ok(NbtTag::Double(double)) } BYTE_ARRAY_ID => { - let len = bytes.get_i32() as usize; - let byte_array = bytes.copy_to_bytes(len); + let len = reader.get_i32_be()?; + if len < 0 { + return Err(Error::NegativeLength(len)); + } + + let byte_array = reader.read_boxed_slice(len as usize)?; Ok(NbtTag::ByteArray(byte_array)) } - STRING_ID => Ok(NbtTag::String(get_nbt_string(bytes).unwrap())), + STRING_ID => Ok(NbtTag::String(get_nbt_string(reader)?)), LIST_ID => { - let tag_type_id = bytes.get_u8(); - let len = bytes.get_i32(); + let tag_type_id = reader.get_u8_be()?; + let len = reader.get_i32_be()?; + if len < 0 { + return Err(Error::NegativeLength(len)); + } + let mut list = Vec::with_capacity(len as usize); for _ in 0..len { - let tag = NbtTag::deserialize_data(bytes, tag_type_id)?; + let tag = NbtTag::deserialize_data(reader, tag_type_id)?; assert_eq!(tag.get_type_id(), tag_type_id); list.push(tag); } - Ok(NbtTag::List(list)) + Ok(NbtTag::List(list.into_boxed_slice())) } - COMPOUND_ID => Ok(NbtTag::Compound(NbtCompound::deserialize_content(bytes)?)), + COMPOUND_ID => Ok(NbtTag::Compound(NbtCompound::deserialize_content(reader)?)), INT_ARRAY_ID => { - let len = bytes.get_i32() as usize; + let len = reader.get_i32_be()?; + if len < 0 { + return Err(Error::NegativeLength(len)); + } + + let len = len as usize; let mut int_array = Vec::with_capacity(len); for _ in 0..len { - let int = bytes.get_i32(); + let int = reader.get_i32_be()?; int_array.push(int); } - Ok(NbtTag::IntArray(int_array)) + Ok(NbtTag::IntArray(int_array.into_boxed_slice())) } LONG_ARRAY_ID => { - let len = bytes.get_i32() as usize; + let len = reader.get_i32_be()?; + if len < 0 { + return Err(Error::NegativeLength(len)); + } + + let len = len as usize; let mut long_array = Vec::with_capacity(len); for _ in 0..len { - let long = bytes.get_i64(); + let long = reader.get_i64_be()?; long_array.push(long); } - Ok(NbtTag::LongArray(long_array)) + Ok(NbtTag::LongArray(long_array.into_boxed_slice())) } _ => Err(Error::UnknownTagId(tag_id)), } } - pub fn deserialize_data_from_cursor( - cursor: &mut Cursor<&[u8]>, - tag_id: u8, - ) -> Result { - Self::deserialize_data(cursor, tag_id) - } - pub fn extract_byte(&self) -> Option { match self { NbtTag::Byte(byte) => Some(*byte), @@ -214,7 +314,7 @@ impl NbtTag { } } - pub fn extract_byte_array(&self) -> Option { + pub fn extract_byte_array(&self) -> Option> { match self { // Note: Bytes are free to clone, so we can hand out an owned type NbtTag::ByteArray(byte_array) => Some(byte_array.clone()), @@ -229,7 +329,7 @@ impl NbtTag { } } - pub fn extract_list(&self) -> Option<&Vec> { + pub fn extract_list(&self) -> Option<&[NbtTag]> { match self { NbtTag::List(list) => Some(list), _ => None, @@ -243,14 +343,14 @@ impl NbtTag { } } - pub fn extract_int_array(&self) -> Option<&Vec> { + pub fn extract_int_array(&self) -> Option<&[i32]> { match self { NbtTag::IntArray(int_array) => Some(int_array), _ => None, } } - pub fn extract_long_array(&self) -> Option<&Vec> { + pub fn extract_long_array(&self) -> Option<&[i64]> { match self { NbtTag::LongArray(long_array) => Some(long_array), _ => None, @@ -266,7 +366,9 @@ impl From<&str> for NbtTag { impl From<&[u8]> for NbtTag { fn from(value: &[u8]) -> Self { - NbtTag::ByteArray(Bytes::copy_from_slice(value)) + let mut cloned = Vec::with_capacity(value.len()); + cloned.copy_from_slice(value); + NbtTag::ByteArray(cloned.into_boxed_slice()) } } diff --git a/pumpkin-protocol/src/bytebuf/serializer.rs b/pumpkin-protocol/src/bytebuf/serializer.rs index 5682aa7b8..3d3afb55a 100644 --- a/pumpkin-protocol/src/bytebuf/serializer.rs +++ b/pumpkin-protocol/src/bytebuf/serializer.rs @@ -94,13 +94,25 @@ impl ser::Serializer for &mut Serializer { } fn serialize_newtype_struct( self, - _name: &'static str, - _value: &T, + name: &'static str, + value: &T, ) -> Result where T: ?Sized + Serialize, { - unimplemented!() + // TODO: This is super sketchy... is there a way to do it better? Can we choose what + // serializer to use on a struct somehow from within the struct? + if name == "TextComponent" { + let mut buf = Vec::new(); + let mut nbt_serializer = pumpkin_nbt::serializer::Serializer::new(&mut buf, None); + value + .serialize(&mut nbt_serializer) + .expect("Failed to serialize NBT for TextComponent within the network serializer"); + + self.serialize_bytes(&buf) + } else { + value.serialize(self) + } } fn serialize_newtype_variant( self, diff --git a/pumpkin-protocol/src/client/config/registry_data.rs b/pumpkin-protocol/src/client/config/registry_data.rs index eb69e5da1..0733097bd 100644 --- a/pumpkin-protocol/src/client/config/registry_data.rs +++ b/pumpkin-protocol/src/client/config/registry_data.rs @@ -1,6 +1,7 @@ -use bytes::{BufMut, BytesMut}; +use bytes::BufMut; use pumpkin_data::packet::clientbound::CONFIG_REGISTRY_DATA; use pumpkin_macros::client_packet; +use serde::Serialize; use crate::{bytebuf::ByteBufMut, codec::identifier::Identifier, ClientPacket}; @@ -21,7 +22,18 @@ impl<'a> CRegistryData<'a> { pub struct RegistryEntry { pub entry_id: Identifier, - pub data: Option, + pub data: Option>, +} + +impl RegistryEntry { + pub fn from_nbt(name: &str, nbt: &impl Serialize) -> Self { + let mut data_buf = Vec::new(); + pumpkin_nbt::serializer::to_bytes_unnamed(nbt, &mut data_buf).unwrap(); + RegistryEntry { + entry_id: Identifier::vanilla(name), + data: Some(data_buf.into_boxed_slice()), + } + } } impl ClientPacket for CRegistryData<'_> { diff --git a/pumpkin-protocol/src/client/play/block_entity_data.rs b/pumpkin-protocol/src/client/play/block_entity_data.rs index 00007bce5..5d843c4dd 100644 --- a/pumpkin-protocol/src/client/play/block_entity_data.rs +++ b/pumpkin-protocol/src/client/play/block_entity_data.rs @@ -10,11 +10,11 @@ use crate::VarInt; pub struct CBlockEntityData { location: BlockPos, r#type: VarInt, - nbt_data: Vec, + nbt_data: Box<[u8]>, } impl CBlockEntityData { - pub fn new(location: BlockPos, r#type: VarInt, nbt_data: Vec) -> Self { + pub fn new(location: BlockPos, r#type: VarInt, nbt_data: Box<[u8]>) -> Self { Self { location, r#type, diff --git a/pumpkin-protocol/src/client/play/chunk_data.rs b/pumpkin-protocol/src/client/play/chunk_data.rs index 4b265748f..dfe4af28e 100644 --- a/pumpkin-protocol/src/client/play/chunk_data.rs +++ b/pumpkin-protocol/src/client/play/chunk_data.rs @@ -18,7 +18,8 @@ impl ClientPacket for CChunkData<'_> { // Chunk Z buf.put_i32(self.0.position.z); - let heightmap_nbt = pumpkin_nbt::serializer::to_bytes_unnamed(&self.0.heightmap).unwrap(); + let mut heightmap_nbt = Vec::new(); + pumpkin_nbt::serializer::to_bytes_unnamed(&self.0.heightmap, &mut heightmap_nbt).unwrap(); // Heightmaps buf.put_slice(&heightmap_nbt); diff --git a/pumpkin-protocol/src/client/play/update_objectives.rs b/pumpkin-protocol/src/client/play/update_objectives.rs index e74b0e447..b98fdff7b 100644 --- a/pumpkin-protocol/src/client/play/update_objectives.rs +++ b/pumpkin-protocol/src/client/play/update_objectives.rs @@ -47,7 +47,9 @@ impl ClientPacket for CUpdateObjectives<'_> { NumberFormat::Styled(style) => { p.put_var_int(&VarInt(1)); // TODO - p.put_slice(&pumpkin_nbt::serializer::to_bytes_unnamed(style).unwrap()); + let mut style_buf = Vec::new(); + pumpkin_nbt::serializer::to_bytes_unnamed(style, &mut style_buf).unwrap(); + p.put_slice(&style_buf); } NumberFormat::Fixed(text_component) => { p.put_var_int(&VarInt(2)); diff --git a/pumpkin-protocol/src/lib.rs b/pumpkin-protocol/src/lib.rs index 8d6711052..d64ababab 100644 --- a/pumpkin-protocol/src/lib.rs +++ b/pumpkin-protocol/src/lib.rs @@ -82,10 +82,12 @@ pub struct RawPacket { pub bytebuf: Bytes, } +// TODO: Have the input be `impl Write` pub trait ClientPacket: Packet { fn write(&self, bytebuf: &mut impl BufMut); } +// TODO: Have the input be `impl Read` pub trait ServerPacket: Packet + Sized { fn read(bytebuf: &mut impl Buf) -> Result; } diff --git a/pumpkin-registry/src/jukebox_song.rs b/pumpkin-registry/src/jukebox_song.rs index adcb0e440..b554ceb15 100644 --- a/pumpkin-registry/src/jukebox_song.rs +++ b/pumpkin-registry/src/jukebox_song.rs @@ -5,7 +5,7 @@ pub struct JukeboxSong { sound_event: String, description: Description, length_in_seconds: f32, - comparator_output: u32, + comparator_output: i32, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/pumpkin-registry/src/lib.rs b/pumpkin-registry/src/lib.rs index 3ac8faaa8..9b6e9e035 100644 --- a/pumpkin-registry/src/lib.rs +++ b/pumpkin-registry/src/lib.rs @@ -92,10 +92,7 @@ impl Registry { let registry_entries = SYNCED_REGISTRIES .biome .iter() - .map(|s| RegistryEntry { - entry_id: Identifier::vanilla(s.0), - data: Some(pumpkin_nbt::serializer::to_bytes_unnamed(&s.1).unwrap()), - }) + .map(|(name, nbt)| RegistryEntry::from_nbt(name, nbt)) .collect(); let biome = Registry { registry_id: Identifier::vanilla("worldgen/biome"), @@ -105,11 +102,9 @@ impl Registry { let registry_entries = SYNCED_REGISTRIES .chat_type .iter() - .map(|s| RegistryEntry { - entry_id: Identifier::vanilla(s.0), - data: Some(pumpkin_nbt::serializer::to_bytes_unnamed(&s.1).unwrap()), - }) + .map(|(name, nbt)| RegistryEntry::from_nbt(name, nbt)) .collect(); + let chat_type = Registry { registry_id: Identifier::vanilla("chat_type"), registry_entries, @@ -144,13 +139,7 @@ impl Registry { let registry_entries = SYNCED_REGISTRIES .wolf_variant .iter() - .map(|s| { - let variant = s.1.clone(); - RegistryEntry { - entry_id: Identifier::vanilla(s.0), - data: Some(pumpkin_nbt::serializer::to_bytes_unnamed(&variant).unwrap()), - } - }) + .map(|(name, nbt)| RegistryEntry::from_nbt(name, nbt)) .collect(); let wolf_variant = Registry { registry_id: Identifier::vanilla("wolf_variant"), @@ -160,10 +149,7 @@ impl Registry { let registry_entries = SYNCED_REGISTRIES .painting_variant .iter() - .map(|s| RegistryEntry { - entry_id: Identifier::vanilla(s.0), - data: Some(pumpkin_nbt::serializer::to_bytes_unnamed(&s.1).unwrap()), - }) + .map(|(name, nbt)| RegistryEntry::from_nbt(name, nbt)) .collect(); let painting_variant = Registry { registry_id: Identifier::vanilla("painting_variant"), @@ -173,10 +159,7 @@ impl Registry { let registry_entries = SYNCED_REGISTRIES .dimension_type .iter() - .map(|s| RegistryEntry { - entry_id: Identifier::vanilla(s.0), - data: Some(pumpkin_nbt::serializer::to_bytes_unnamed(&s.1).unwrap()), - }) + .map(|(name, nbt)| RegistryEntry::from_nbt(name, nbt)) .collect(); let dimension_type = Registry { registry_id: Identifier::vanilla("dimension_type"), @@ -186,10 +169,7 @@ impl Registry { let registry_entries = SYNCED_REGISTRIES .damage_type .iter() - .map(|s| RegistryEntry { - entry_id: Identifier::vanilla(s.0), - data: Some(pumpkin_nbt::serializer::to_bytes_unnamed(&s.1).unwrap()), - }) + .map(|(name, nbt)| RegistryEntry::from_nbt(name, nbt)) .collect(); let damage_type = Registry { registry_id: Identifier::vanilla("damage_type"), @@ -199,10 +179,7 @@ impl Registry { let registry_entries = SYNCED_REGISTRIES .banner_pattern .iter() - .map(|s| RegistryEntry { - entry_id: Identifier::vanilla(s.0), - data: Some(pumpkin_nbt::serializer::to_bytes_unnamed(&s.1).unwrap()), - }) + .map(|(name, nbt)| RegistryEntry::from_nbt(name, nbt)) .collect(); let banner_pattern = Registry { registry_id: Identifier::vanilla("banner_pattern"), @@ -226,10 +203,7 @@ impl Registry { let registry_entries = SYNCED_REGISTRIES .jukebox_song .iter() - .map(|s| RegistryEntry { - entry_id: Identifier::vanilla(s.0), - data: Some(pumpkin_nbt::serializer::to_bytes_unnamed(&s.1).unwrap()), - }) + .map(|(name, nbt)| RegistryEntry::from_nbt(name, nbt)) .collect(); let jukebox_song = Registry { registry_id: Identifier::vanilla("jukebox_song"), diff --git a/pumpkin-util/src/text/mod.rs b/pumpkin-util/src/text/mod.rs index 6a6def038..0141e7483 100644 --- a/pumpkin-util/src/text/mod.rs +++ b/pumpkin-util/src/text/mod.rs @@ -15,8 +15,7 @@ pub mod hover; pub mod style; /// Represents a Text component -#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Hash)] -#[serde(rename_all = "camelCase")] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct TextComponent(pub TextComponentBase); #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -153,27 +152,14 @@ impl TextComponent { } } -impl serde::Serialize for TextComponent { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_bytes(&self.encode()) - } -} - -impl pumpkin_nbt::serializer::SerializeChild for TextComponent { - fn serialize_child(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.0.serialize(serializer) - } -} - impl TextComponent { - pub fn encode(&self) -> bytes::BytesMut { - pumpkin_nbt::serializer::to_bytes_text_component(self).unwrap() + pub fn encode(&self) -> Box<[u8]> { + let mut buf = Vec::new(); + // TODO: Properly handle errors + pumpkin_nbt::serializer::to_bytes_unnamed(&self.0, &mut buf) + .expect("Failed to serialize text component NBT for encode"); + + buf.into_boxed_slice() } pub fn color(mut self, color: Color) -> Self { @@ -274,3 +260,33 @@ pub enum TextContent { /// https://minecraft.wiki/w/Controls#Configurable_controls Keybind { keybind: Cow<'static, str> }, } + +#[cfg(test)] +mod test { + use pumpkin_nbt::serializer::to_bytes_unnamed; + + use crate::text::{color::NamedColor, TextComponent}; + + #[test] + fn test_serialize_text_component() { + let msg_comp = TextComponent::translate( + "multiplayer.player.joined", + [TextComponent::text("NAME".to_string())].into(), + ) + .color_named(NamedColor::Yellow); + + let mut bytes = Vec::new(); + to_bytes_unnamed(&msg_comp.0, &mut bytes).unwrap(); + + let expected_bytes = [ + 0x0A, 0x08, 0x00, 0x09, 0x74, 0x72, 0x61, 0x6E, 0x73, 0x6C, 0x61, 0x74, 0x65, 0x00, + 0x19, 0x6D, 0x75, 0x6C, 0x74, 0x69, 0x70, 0x6C, 0x61, 0x79, 0x65, 0x72, 0x2E, 0x70, + 0x6C, 0x61, 0x79, 0x65, 0x72, 0x2E, 0x6A, 0x6F, 0x69, 0x6E, 0x65, 0x64, 0x09, 0x00, + 0x04, 0x77, 0x69, 0x74, 0x68, 0x0A, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x04, 0x74, + 0x65, 0x78, 0x74, 0x00, 0x04, 0x4E, 0x41, 0x4D, 0x45, 0x00, 0x08, 0x00, 0x05, 0x63, + 0x6F, 0x6C, 0x6F, 0x72, 0x00, 0x06, 0x79, 0x65, 0x6C, 0x6C, 0x6F, 0x77, 0x00, + ]; + + assert_eq!(bytes, expected_bytes); + } +} diff --git a/pumpkin-world/Cargo.toml b/pumpkin-world/Cargo.toml index 7ab39505e..387a60459 100644 --- a/pumpkin-world/Cargo.toml +++ b/pumpkin-world/Cargo.toml @@ -38,8 +38,6 @@ indexmap = "2.7" enum_dispatch = "0.3" -fastnbt = { git = "https://github.com/owengage/fastnbt.git" } - noise = "0.9" rand = "0.8" @@ -49,6 +47,7 @@ serde_json5 = { git = "https://github.com/kralverde/serde_json5.git" } derive-getters = "0.5.0" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } +temp-dir = "0.1.14" [[bench]] name = "chunk_noise_populate" diff --git a/pumpkin-world/assets/level.dat b/pumpkin-world/assets/level.dat new file mode 100644 index 000000000..6b3de1da6 Binary files /dev/null and b/pumpkin-world/assets/level.dat differ diff --git a/pumpkin-world/src/chunk/anvil.rs b/pumpkin-world/src/chunk/anvil.rs index 43c853c70..dd9070da5 100644 --- a/pumpkin-world/src/chunk/anvil.rs +++ b/pumpkin-world/src/chunk/anvil.rs @@ -1,8 +1,8 @@ use bytes::*; -use fastnbt::LongArray; use flate2::bufread::{GzDecoder, GzEncoder, ZlibDecoder, ZlibEncoder}; use indexmap::IndexMap; use pumpkin_config::ADVANCED_CONFIG; +use pumpkin_nbt::serializer::to_bytes; use pumpkin_util::math::ceil_log2; use std::time::{SystemTime, UNIX_EPOCH}; use std::{ @@ -425,7 +425,7 @@ impl AnvilChunkFormat { sections.push(ChunkSection { y: i as i8 - 4, block_states: Some(ChunkSectionBlockStates { - data: Some(LongArray::new(section_longs)), + data: Some(section_longs.into_boxed_slice()), palette: palette .into_iter() .map(|entry| PaletteEntry { @@ -446,7 +446,9 @@ impl AnvilChunkFormat { sections, }; - fastnbt::to_bytes(&nbt).map_err(ChunkSerializingError::ErrorSerializingChunk) + let mut result = Vec::new(); + to_bytes(&nbt, &mut result).map_err(ChunkSerializingError::ErrorSerializingChunk)?; + Ok(result) } /// Returns the next free writable sector @@ -565,4 +567,47 @@ mod tests { println!("Checked chunks successfully"); } + + // TODO + /* + #[test] + fn test_load_java_chunk() { + let temp_dir = TempDir::new().unwrap(); + let level_folder = LevelFolder { + root_folder: temp_dir.path().to_path_buf(), + region_folder: temp_dir.path().join("region"), + }; + + fs::create_dir(&level_folder.region_folder).unwrap(); + fs::copy( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join(file!()) + .parent() + .unwrap() + .join("../../assets/r.0.0.mca"), + level_folder.region_folder.join("r.0.0.mca"), + ) + .unwrap(); + + let mut actually_tested = false; + for x in 0..(1 << 5) { + for z in 0..(1 << 5) { + let result = AnvilChunkFormat {}.read_chunk(&level_folder, &Vector2 { x, z }); + + match result { + Ok(_) => actually_tested = true, + Err(ChunkReadingError::ParsingError(ChunkParsingError::ChunkNotGenerated)) => {} + Err(ChunkReadingError::ChunkNotExist) => {} + Err(e) => panic!("{:?}", e), + } + + println!("=========== OK ==========="); + } + } + + assert!(actually_tested); + } + */ } diff --git a/pumpkin-world/src/chunk/mod.rs b/pumpkin-world/src/chunk/mod.rs index 7c2b1abbe..6b2143a84 100644 --- a/pumpkin-world/src/chunk/mod.rs +++ b/pumpkin-world/src/chunk/mod.rs @@ -2,8 +2,8 @@ use dashmap::{ mapref::one::{Ref, RefMut}, DashMap, }; -use fastnbt::LongArray; use pumpkin_data::chunk::ChunkStatus; +use pumpkin_nbt::{deserializer::from_bytes, nbt_long_array}; use pumpkin_util::math::{ceil_log2, vector2::Vector2}; use serde::{Deserialize, Serialize}; use std::{ @@ -161,29 +161,34 @@ pub enum Subchunk { struct PaletteEntry { // block name name: String, + #[serde(skip_serializing_if = "Option::is_none")] properties: Option>, } #[derive(Deserialize, Serialize, Debug, Clone)] #[serde(rename_all = "UPPERCASE")] pub struct ChunkHeightmaps { - // #[serde(with = "LongArray")] - motion_blocking: LongArray, - // #[serde(with = "LongArray")] - world_surface: LongArray, + #[serde(serialize_with = "nbt_long_array")] + motion_blocking: Box<[i64]>, + #[serde(serialize_with = "nbt_long_array")] + world_surface: Box<[i64]>, } #[derive(Serialize, Deserialize, Debug)] struct ChunkSection { #[serde(rename = "Y")] y: i8, + #[serde(skip_serializing_if = "Option::is_none")] block_states: Option, } #[derive(Serialize, Deserialize, Debug, Clone)] struct ChunkSectionBlockStates { - // #[serde(with = "LongArray")] - data: Option, + #[serde( + serialize_with = "nbt_long_array", + skip_serializing_if = "Option::is_none" + )] + data: Option>, palette: Vec, } @@ -234,8 +239,8 @@ impl Default for ChunkHeightmaps { fn default() -> Self { Self { // 0 packed into an i64 7 times. - motion_blocking: LongArray::new(vec![0; 37]), - world_surface: LongArray::new(vec![0; 37]), + motion_blocking: vec![0; 37].into_boxed_slice(), + world_surface: vec![0; 37].into_boxed_slice(), } } } @@ -402,15 +407,15 @@ impl ChunkData { chunk_data: &[u8], position: Vector2, ) -> Result { - if fastnbt::from_bytes::(chunk_data) - .map_err(|_| ChunkParsingError::FailedReadStatus)? + if from_bytes::(chunk_data) + .map_err(ChunkParsingError::FailedReadStatus)? .status != ChunkStatus::Full { return Err(ChunkParsingError::ChunkNotGenerated); } - let chunk_data = fastnbt::from_bytes::(chunk_data) + let chunk_data = from_bytes::(chunk_data) .map_err(|e| ChunkParsingError::ErrorDeserializingChunk(e.to_string()))?; if chunk_data.x_pos != position.x || chunk_data.z_pos != position.z { @@ -502,8 +507,8 @@ impl ChunkData { #[derive(Error, Debug)] pub enum ChunkParsingError { - #[error("Failed reading chunk status")] - FailedReadStatus, + #[error("Failed reading chunk status {0}")] + FailedReadStatus(pumpkin_nbt::Error), #[error("The chunk isn't generated yet")] ChunkNotGenerated, #[error("Error deserializing chunk: {0}")] @@ -517,5 +522,5 @@ fn convert_index(index: ChunkRelativeBlockCoordinates) -> usize { #[derive(Error, Debug)] pub enum ChunkSerializingError { #[error("Error serializing chunk: {0}")] - ErrorSerializingChunk(fastnbt::error::Error), + ErrorSerializingChunk(pumpkin_nbt::Error), } diff --git a/pumpkin-world/src/world_info/anvil.rs b/pumpkin-world/src/world_info/anvil.rs index d02b32e27..d0a5c1673 100644 --- a/pumpkin-world/src/world_info/anvil.rs +++ b/pumpkin-world/src/world_info/anvil.rs @@ -1,10 +1,10 @@ use std::{ fs::OpenOptions, - io::{Read, Write}, time::{SystemTime, UNIX_EPOCH}, }; use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use pumpkin_nbt::{deserializer::from_bytes, serializer::to_bytes}; use serde::{Deserialize, Serialize}; use crate::level::LevelFolder; @@ -19,20 +19,12 @@ impl WorldInfoReader for AnvilLevelInfo { fn read_world_info(&self, level_folder: &LevelFolder) -> Result { let path = level_folder.root_folder.join(LEVEL_DAT_FILE_NAME); - let mut world_info_file = OpenOptions::new().read(true).open(path)?; - - let mut buffer = Vec::new(); - world_info_file.read_to_end(&mut buffer)?; - - // try to decompress using GZip - let mut decoder = GzDecoder::new(&buffer[..]); - let mut decompressed_data = Vec::new(); - decoder.read_to_end(&mut decompressed_data)?; - - let info = fastnbt::from_bytes::(&decompressed_data) + let world_info_file = OpenOptions::new().read(true).open(path)?; + let compression_reader = GzDecoder::new(world_info_file); + let info = from_bytes::(compression_reader) .map_err(|e| WorldInfoError::DeserializationError(e.to_string()))?; - // todo check version + // TODO: check version Ok(info.data) } @@ -48,46 +40,134 @@ impl WorldInfoWriter for AnvilLevelInfo { let since_the_epoch = start .duration_since(UNIX_EPOCH) .expect("Time went backwards"); - let level = LevelDat { - data: LevelData { - allow_commands: info.allow_commands, - data_version: info.data_version, - difficulty: info.difficulty, - world_gen_settings: info.world_gen_settings, - last_played: since_the_epoch.as_millis() as i64, - level_name: info.level_name, - spawn_x: info.spawn_x, - spawn_y: info.spawn_y, - spawn_z: info.spawn_z, - spawn_angle: info.spawn_angle, - nbt_version: info.nbt_version, - version: info.version, - }, - }; - // convert it into nbt - let nbt = pumpkin_nbt::serializer::to_bytes_unnamed(&level).unwrap(); - // now compress using GZip, TODO: im not sure about the to_vec, but writer is not implemented for BytesMut, see https://github.com/tokio-rs/bytes/pull/478 - let mut encoder = GzEncoder::new(nbt.to_vec(), Compression::best()); - let compressed_data = Vec::new(); - encoder.write_all(&compressed_data)?; + let mut level_data = info.clone(); + level_data.last_played = since_the_epoch.as_millis() as i64; + let level = LevelDat { data: level_data }; // open file let path = level_folder.root_folder.join(LEVEL_DAT_FILE_NAME); - let mut world_info_file = OpenOptions::new() + let world_info_file = OpenOptions::new() .truncate(true) .create(true) .write(true) .open(path)?; - // write compressed data into file - world_info_file.write_all(&compressed_data).unwrap(); + // write compressed data into file + let compression_writer = GzEncoder::new(world_info_file, Compression::best()); + // TODO: Proper error handling + to_bytes(&level, compression_writer).expect("Failed to write level.dat to disk"); Ok(()) } } -#[derive(Serialize, Deserialize)] +#[derive(Debug, PartialEq, Serialize, Deserialize)] pub struct LevelDat { // This tag contains all the level data. #[serde(rename = "Data")] pub data: LevelData, } + +#[cfg(test)] +mod test { + + use std::sync::LazyLock; + + use flate2::read::GzDecoder; + use pumpkin_nbt::{deserializer::from_bytes, serializer::to_bytes}; + use temp_dir::TempDir; + + use crate::{ + level::LevelFolder, + world_info::{DataPacks, LevelData, WorldGenSettings, WorldVersion}, + }; + + use super::{AnvilLevelInfo, LevelDat, WorldInfoReader, WorldInfoWriter}; + + #[test] + fn test_preserve_level_dat_seed() { + let seed = 1337; + + let mut data = LevelData::default(); + data.world_gen_settings.seed = seed; + + let temp_dir = TempDir::new().unwrap(); + let level_folder = LevelFolder { + root_folder: temp_dir.path().to_path_buf(), + region_folder: temp_dir.path().join("region"), + }; + + AnvilLevelInfo + .write_world_info(data, &level_folder) + .unwrap(); + + let data = AnvilLevelInfo.read_world_info(&level_folder).unwrap(); + + assert_eq!(data.world_gen_settings.seed, seed); + } + + static LEVEL_DAT: LazyLock = LazyLock::new(|| LevelDat { + data: LevelData { + allow_commands: true, + border_center_x: 0.0, + border_center_z: 0.0, + border_damage_per_block: 0.2, + border_size: 59_999_968.0, + border_safe_zone: 5.0, + border_size_lerp_target: 59_999_968.0, + border_size_lerp_time: 0, + border_warning_blocks: 5.0, + border_warning_time: 15.0, + clear_weather_time: 0, + data_packs: DataPacks { + disabled: vec![ + "minecart_improvements".to_string(), + "redstone_experiments".to_string(), + "trade_rebalance".to_string(), + ], + enabled: vec!["vanilla".to_string()], + }, + data_version: 4189, + day_time: 1727, + difficulty: 2, + difficulty_locked: false, + world_gen_settings: WorldGenSettings { seed: 1 }, + last_played: 1733847709327, + level_name: "New World".to_string(), + spawn_x: 160, + spawn_y: 70, + spawn_z: 160, + spawn_angle: 0.0, + nbt_version: 19133, + version: WorldVersion { + name: "1.21.4".to_string(), + id: 4189, + snapshot: false, + series: "main".to_string(), + }, + }, + }); + + #[test] + fn test_deserialize_level_dat() { + let raw_compressed_nbt = include_bytes!("../../assets/level.dat"); + assert!(!raw_compressed_nbt.is_empty()); + + let decoder = GzDecoder::new(&raw_compressed_nbt[..]); + let level_dat: LevelDat = from_bytes(decoder).expect("Failed to decode from file"); + + assert_eq!(level_dat, *LEVEL_DAT); + } + + #[test] + fn test_serialize_level_dat() { + let mut serialized = Vec::new(); + to_bytes(&*LEVEL_DAT, &mut serialized).expect("Failed to encode to bytes"); + + assert!(!serialized.is_empty()); + + let level_dat_again: LevelDat = + from_bytes(&serialized[..]).expect("Failed to decode from bytes"); + + assert_eq!(level_dat_again, *LEVEL_DAT); + } +} diff --git a/pumpkin-world/src/world_info/mod.rs b/pumpkin-world/src/world_info/mod.rs index 0c31abdde..de6217f86 100644 --- a/pumpkin-world/src/world_info/mod.rs +++ b/pumpkin-world/src/world_info/mod.rs @@ -19,16 +19,47 @@ pub(crate) trait WorldInfoWriter: Sync + Send { ) -> Result<(), WorldInfoError>; } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] #[serde(rename_all = "PascalCase")] -#[serde(default)] pub struct LevelData { // true if cheats are enabled. + #[serde(rename = "allowCommands")] pub allow_commands: bool, + // Center of the world border on the X coordinate. Defaults to 0. + pub border_center_x: f64, + // Center of the world border on the Z coordinate. Defaults to 0. + pub border_center_z: f64, + // Defaults to 0.2. + pub border_damage_per_block: f64, + // Width and length of the border of the border. Defaults to 60000000. + pub border_size: f64, + // Defaults to 5. + pub border_safe_zone: f64, + // Defaults to 60000000. + pub border_size_lerp_target: f64, + // Defaults to 0. + pub border_size_lerp_time: i64, + // Defaults to 5. + pub border_warning_blocks: f64, + // Defaults to 15. + pub border_warning_time: f64, + // The number of ticks until "clear weather" has ended. + #[serde(rename = "clearWeatherTime")] + pub clear_weather_time: i32, + // TODO: Custom Boss Events + + // Options for data packs. + pub data_packs: DataPacks, // An integer displaying the data version. pub data_version: i32, + // The time of day. 0 is sunrise, 6000 is mid day, 12000 is sunset, 18000 is mid night, 24000 is the next day's 0. This value keeps counting past 24000 and does not reset to 0. + pub day_time: i64, // The current difficulty setting. - pub difficulty: u8, + pub difficulty: i8, + // 1 or 0 (true/false) - True if the difficulty has been locked. Defaults to 0. + pub difficulty_locked: bool, + // TODO: DimensionData + // the generation settings for each dimension. pub world_gen_settings: WorldGenSettings, // The Unix time in milliseconds when the level was last loaded. @@ -51,12 +82,21 @@ pub struct LevelData { // TODO: Implement the rest of the fields } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] pub struct WorldGenSettings { // the numerical seed of the world pub seed: i64, } +#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] +#[serde(rename_all = "PascalCase")] +pub struct DataPacks { + // List of disabled data packs. + pub disabled: Vec, + // List of enabled data packs. By default, this is populated with a single string "vanilla". + pub enabled: Vec, +} + fn get_or_create_seed() -> Seed { // TODO: if there is a seed in the config (!= 0) use it. Otherwise make a random one Seed::from(BASIC_CONFIG.seed.as_str()) @@ -70,7 +110,7 @@ impl Default for WorldGenSettings { } } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] #[serde(rename_all = "PascalCase")] pub struct WorldVersion { // The version name as a string, e.g. "15w32b". @@ -98,9 +138,24 @@ impl Default for LevelData { fn default() -> Self { Self { allow_commands: true, - // TODO + border_center_x: 0.0, + border_center_z: 0.0, + border_damage_per_block: 0.2, + border_size: 60_000_000.0, + border_safe_zone: 5.0, + border_size_lerp_target: 60_000_000.0, + border_size_lerp_time: 0, + border_warning_blocks: 5.0, + border_warning_time: 15.0, + clear_weather_time: -1, + data_packs: DataPacks { + disabled: vec![], + enabled: vec!["vanilla".to_string()], + }, data_version: -1, - difficulty: Difficulty::Normal as u8, + day_time: 0, + difficulty: Difficulty::Normal as i8, + difficulty_locked: false, world_gen_settings: Default::default(), last_played: -1, level_name: "world".to_string(), diff --git a/pumpkin/src/entity/mod.rs b/pumpkin/src/entity/mod.rs index 61237c41b..9533c180b 100644 --- a/pumpkin/src/entity/mod.rs +++ b/pumpkin/src/entity/mod.rs @@ -382,24 +382,20 @@ impl NBTStorage for Entity { let position = self.pos.load(); nbt.put( "Pos", - NbtTag::List(vec![ - position.x.into(), - position.y.into(), - position.z.into(), - ]), + NbtTag::List( + vec![position.x.into(), position.y.into(), position.z.into()].into_boxed_slice(), + ), ); let velocity = self.velocity.load(); nbt.put( "Motion", - NbtTag::List(vec![ - velocity.x.into(), - velocity.y.into(), - velocity.z.into(), - ]), + NbtTag::List( + vec![velocity.x.into(), velocity.y.into(), velocity.z.into()].into_boxed_slice(), + ), ); nbt.put( "Rotation", - NbtTag::List(vec![self.yaw.load().into(), self.pitch.load().into()]), + NbtTag::List(vec![self.yaw.load().into(), self.pitch.load().into()].into_boxed_slice()), ); // todo more... diff --git a/pumpkin/src/net/packet/play.rs b/pumpkin/src/net/packet/play.rs index 7ee7a00d4..40a5e54fe 100644 --- a/pumpkin/src/net/packet/play.rs +++ b/pumpkin/src/net/packet/play.rs @@ -1065,13 +1065,13 @@ impl Player { ], ); + let mut sign_buf = Vec::new(); + pumpkin_nbt::serializer::to_bytes_unnamed(&updated_sign, &mut sign_buf).unwrap(); world .broadcast_packet_all(&CBlockEntityData::new( sign_data.location, VarInt(block_entity!("sign") as i32), - pumpkin_nbt::serializer::to_bytes_unnamed(&updated_sign) - .unwrap() - .to_vec(), + sign_buf.into_boxed_slice(), )) .await; }