diff --git a/twilight-model/src/gateway/payload/incoming/guild_create.rs b/twilight-model/src/gateway/payload/incoming/guild_create.rs index 348d647a6dd..82391c736f7 100644 --- a/twilight-model/src/gateway/payload/incoming/guild_create.rs +++ b/twilight-model/src/gateway/payload/incoming/guild_create.rs @@ -4,11 +4,13 @@ use crate::{ }; use serde::{Deserialize, Serialize}; +// Developer note: Do not change order as we want unavailable to fail +// first. #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] #[serde(untagged)] pub enum GuildCreate { - Available(Guild), Unavailable(UnavailableGuild), + Available(Guild), } impl GuildCreate { diff --git a/twilight-model/src/guild/unavailable_guild.rs b/twilight-model/src/guild/unavailable_guild.rs index f4689e53e3c..86a8db4f7c9 100644 --- a/twilight-model/src/guild/unavailable_guild.rs +++ b/twilight-model/src/guild/unavailable_guild.rs @@ -1,12 +1,37 @@ -use crate::id::{marker::GuildMarker, Id}; +use crate::{ + id::{marker::GuildMarker, Id}, + util::mustbe::MustBeBool, +}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize)] pub struct UnavailableGuild { pub id: Id, pub unavailable: bool, } +impl<'de> Deserialize<'de> for UnavailableGuild { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(rename = "UnavailableGuild")] // Tests expect this struct name + struct UnavailableGuildIntermediate { + id: Id, + #[allow(unused)] // Only used in the derived impl + unavailable: MustBeBool, + } + + let intermediate = UnavailableGuildIntermediate::deserialize(deserializer)?; + + Ok(Self { + id: intermediate.id, + unavailable: true, + }) + } +} + #[cfg(test)] mod tests { use super::UnavailableGuild; diff --git a/twilight-model/src/util/mod.rs b/twilight-model/src/util/mod.rs index c30df90e120..7a013cfb727 100644 --- a/twilight-model/src/util/mod.rs +++ b/twilight-model/src/util/mod.rs @@ -3,6 +3,7 @@ pub mod datetime; pub mod hex_color; pub mod image_hash; +pub(crate) mod mustbe; pub use self::{datetime::Timestamp, hex_color::HexColor, image_hash::ImageHash}; diff --git a/twilight-model/src/util/mustbe.rs b/twilight-model/src/util/mustbe.rs new file mode 100644 index 00000000000..f9b65632b26 --- /dev/null +++ b/twilight-model/src/util/mustbe.rs @@ -0,0 +1,84 @@ +//! A struct that only deserializes from one specific boolean value. +//! +//! This module is heavily based upon +//! . + +use std::fmt; + +use serde::{ + de::{Error, Unexpected, Visitor}, + Deserialize, +}; + +/// Struct that will only serialize from the bool specified as `T`. +pub struct MustBeBool; + +impl<'de, const T: bool> Deserialize<'de> for MustBeBool { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct MustBeBoolVisitor(bool); + + impl<'de> Visitor<'de> for MustBeBoolVisitor { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "boolean `{}`", self.0) + } + + fn visit_bool(self, v: bool) -> Result + where + E: Error, + { + if v == self.0 { + Ok(()) + } else { + Err(E::invalid_value(Unexpected::Bool(v), &self)) + } + } + } + + deserializer + .deserialize_any(MustBeBoolVisitor(T)) + .map(|()| MustBeBool) + } +} + +#[cfg(test)] +mod tests { + use super::MustBeBool; + + use serde::Deserialize; + + #[derive(Deserialize)] + struct MTrue { + #[allow(unused)] + m: MustBeBool, + } + + #[derive(Deserialize)] + struct MFalse { + #[allow(unused)] + m: MustBeBool, + } + + #[derive(Deserialize)] + #[serde(untagged)] + enum TestEnum { + VariantTrue(MTrue), + VariantFalse(MFalse), + } + + #[test] + #[allow(unused)] + fn true_false_enum() { + let json_1 = r#"{ "m": false }"#; + let result_1 = serde_json::from_str::(json_1).unwrap(); + assert!(matches!(result_1, TestEnum::VariantFalse(_))); + + let json_2 = r#"{ "m": true }"#; + let result_2 = serde_json::from_str::(json_2).unwrap(); + assert!(matches!(result_2, TestEnum::VariantTrue(_))); + } +}