Skip to content

Commit

Permalink
fix(model): Unavailable guild must always have unavailable as true (#…
Browse files Browse the repository at this point in the history
…2361)

This adds a new construct that can be used to only deserialize a
specific boolean value. We then use that value to ensure that only
guilds which are actually unavailable get deserialized as such.

---------

Signed-off-by: Erk <[email protected]>
Co-authored-by: Jens Reidel <[email protected]>
  • Loading branch information
Erk- and Gelbpunkt authored Sep 8, 2024
1 parent f75a915 commit f58b540
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 3 deletions.
4 changes: 3 additions & 1 deletion twilight-model/src/gateway/payload/incoming/guild_create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ use crate::{
};
use serde::{Deserialize, Serialize};

// Developer note: Do not change order as we want unavailable to fail
// first.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(untagged)]
pub enum GuildCreate {
Available(Guild),
Unavailable(UnavailableGuild),
Available(Guild),
}

impl GuildCreate {
Expand Down
29 changes: 27 additions & 2 deletions twilight-model/src/guild/unavailable_guild.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,37 @@
use crate::id::{marker::GuildMarker, Id};
use crate::{
id::{marker::GuildMarker, Id},
util::mustbe::MustBeBool,
};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize)]
pub struct UnavailableGuild {
pub id: Id<GuildMarker>,
pub unavailable: bool,
}

impl<'de> Deserialize<'de> for UnavailableGuild {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(rename = "UnavailableGuild")] // Tests expect this struct name
struct UnavailableGuildIntermediate {
id: Id<GuildMarker>,
#[allow(unused)] // Only used in the derived impl
unavailable: MustBeBool<true>,
}

let intermediate = UnavailableGuildIntermediate::deserialize(deserializer)?;

Ok(Self {
id: intermediate.id,
unavailable: true,
})
}
}

#[cfg(test)]
mod tests {
use super::UnavailableGuild;
Expand Down
1 change: 1 addition & 0 deletions twilight-model/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pub mod datetime;
pub mod hex_color;
pub mod image_hash;
pub(crate) mod mustbe;

pub use self::{datetime::Timestamp, hex_color::HexColor, image_hash::ImageHash};

Expand Down
84 changes: 84 additions & 0 deletions twilight-model/src/util/mustbe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//! A struct that only deserializes from one specific boolean value.
//!
//! This module is heavily based upon
//! <https://github.com/dtolnay/monostate>.

use std::fmt;

use serde::{
de::{Error, Unexpected, Visitor},
Deserialize,
};

/// Struct that will only serialize from the bool specified as `T`.
pub struct MustBeBool<const T: bool>;

impl<'de, const T: bool> Deserialize<'de> for MustBeBool<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct MustBeBoolVisitor(bool);

impl<'de> Visitor<'de> for MustBeBoolVisitor {
type Value = ();

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "boolean `{}`", self.0)
}

fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
where
E: Error,
{
if v == self.0 {
Ok(())
} else {
Err(E::invalid_value(Unexpected::Bool(v), &self))
}
}
}

deserializer
.deserialize_any(MustBeBoolVisitor(T))
.map(|()| MustBeBool)
}
}

#[cfg(test)]
mod tests {
use super::MustBeBool;

use serde::Deserialize;

#[derive(Deserialize)]
struct MTrue {
#[allow(unused)]
m: MustBeBool<true>,
}

#[derive(Deserialize)]
struct MFalse {
#[allow(unused)]
m: MustBeBool<false>,
}

#[derive(Deserialize)]
#[serde(untagged)]
enum TestEnum {
VariantTrue(MTrue),
VariantFalse(MFalse),
}

#[test]
#[allow(unused)]
fn true_false_enum() {
let json_1 = r#"{ "m": false }"#;
let result_1 = serde_json::from_str::<TestEnum>(json_1).unwrap();
assert!(matches!(result_1, TestEnum::VariantFalse(_)));

let json_2 = r#"{ "m": true }"#;
let result_2 = serde_json::from_str::<TestEnum>(json_2).unwrap();
assert!(matches!(result_2, TestEnum::VariantTrue(_)));
}
}

0 comments on commit f58b540

Please sign in to comment.