diff --git a/crates/binjs_generate_library/src/lib.rs b/crates/binjs_generate_library/src/lib.rs index 890b1b332..7f6a601cb 100644 --- a/crates/binjs_generate_library/src/lib.rs +++ b/crates/binjs_generate_library/src/lib.rs @@ -119,6 +119,13 @@ use io::*; use std::convert::{ From }; +/// Dummy single-variant enum to deserialize only a string `type` and fail otherwise. +#[derive(Deserialize)] +enum TypeKey { + #[serde(rename = \"type\")] + TypeKey, +} + "); // Buffer used to generate the generic data structure (struct declaration). @@ -379,7 +386,7 @@ impl<'a> Walker<'a> for {name} where Self: 'a {{ let definition = format!( " /// Implementation of interface sum {node_name} -#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] +#[derive(PartialEq, Debug, Clone, Serialize)] #[serde(tag = \"type\")] pub enum {name} {{ {contents} @@ -387,6 +394,41 @@ pub enum {name} {{ BinASTStolen, }}\n +// An optimised implementation of tagged deserialise that expects `type` to be the first key in the object. +// +// This does not strictly adhere to JSON spec, but gives ~2.5x better performance than generic +// deserialistaion with arbitrary ordering. +// +// See https://github.com/serde-rs/serde/issues/1495 for details. +impl<'de> serde::Deserialize<'de> for {name} {{ + fn deserialize>(de: D) -> Result {{ + #[derive(Deserialize)] + enum VariantTag {{ + {variant_tags} + }} + + struct MapVisitor; + + impl<'de> serde::de::Visitor<'de> for MapVisitor {{ + type Value = {name}; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {{ + f.write_str(\"an object\") + }} + + fn visit_map>(self, mut map: A) -> Result<{name}, A::Error> {{ + let (_, variant): (TypeKey, VariantTag) = map.next_entry()?.ok_or_else(|| serde::de::Error::invalid_length(0, &\"1 or more items\"))?; + let de = serde::de::value::MapAccessDeserializer::new(map); + match variant {{ + {value_variants} + }} + }} + }} + + de.deserialize_map(MapVisitor) + }} +}} + /// A mechanism to view value as an instance of interface sum {node_name} /// /// Used to perform shallow cast between larger sums and smaller sums. @@ -408,6 +450,21 @@ pub enum ViewMut{name}<'a> {{\n{ref_mut_contents}\n}}\n", name = case.to_class_cases() )) .format(",\n"), + variant_tags = types + .iter() + .map(|case| format!( + " {name}", + name = case.to_class_cases() + )) + .format(",\n"), + value_variants = types + .iter() + .map(|case| format!( + " VariantTag::{case} => serde::Deserialize::deserialize(de).map({name}::{case})", + name = name, + case = case.to_class_cases() + )) + .format(",\n"), ); let single_variant_from = format!( diff --git a/src/source/shift.rs b/src/source/shift.rs index 5c726f203..e89d0ed29 100644 --- a/src/source/shift.rs +++ b/src/source/shift.rs @@ -95,30 +95,30 @@ impl Script { I: ?Sized + serde::Serialize, O: serde::de::DeserializeOwned, { - let output = (move || { + let output = { let mut io = self.0.lock().unwrap(); - serde_json::to_writer(&mut io.input, input)?; - writeln!(io.input)?; - io.output.next().unwrap() - })() - .map_err(Error::IOError)?; - let mut deserializer = serde_json::Deserializer::from_str(&output); + serde_json::to_writer(&mut io.input, input).map_err(Error::JSONError)?; + writeln!(io.input).map_err(Error::IOError)?; - deserializer.disable_recursion_limit(); + io.output.next().unwrap().map_err(Error::IOError)? + }; #[derive(Deserialize)] #[serde(tag = "type", content = "value")] - enum CustomResult { + #[serde(remote = "std::result::Result")] + enum Result { Ok(T), - Err(String), + Err(E), } - match CustomResult::deserialize(&mut deserializer) { - Ok(CustomResult::Ok(v)) => Ok(v), - Ok(CustomResult::Err(msg)) => Err(Error::ParsingError(msg)), - Err(err) => Err(Error::JSONError(err)), - } + let mut deserializer = serde_json::Deserializer::from_str(&output); + + deserializer.disable_recursion_limit(); + + Result::deserialize(&mut deserializer) + .map_err(Error::JSONError)? + .map_err(Error::ParsingError) } }