Skip to content

Commit

Permalink
fix(model): Unavailable guild must always have unavailable as true
Browse files Browse the repository at this point in the history
This adds a new construct that can be used to only deserialize a
specific boolean value. We then use that value to ensure that only
guilds which are actually unavailable get deserialized as such.
  • Loading branch information
Erk- committed Sep 4, 2024
1 parent 5a9da42 commit aa3cb6e
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 9 deletions.
8 changes: 5 additions & 3 deletions twilight-model/src/gateway/payload/incoming/guild_create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ use crate::{
};
use serde::{Deserialize, Serialize};

// Developer note: Do not change order as we want unavailable to fail
// first.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(untagged)]
pub enum GuildCreate {
Available(Guild),
Unavailable(UnavailableGuild),
Available(Guild),
}

impl GuildCreate {
Expand All @@ -25,15 +27,15 @@ impl GuildCreate {
mod tests {
use serde_test::Token;

use crate::{guild::UnavailableGuild, id::Id};
use crate::{guild::UnavailableGuild, id::Id, util::mustbe::MustBeBool};

use super::GuildCreate;

#[test]
fn unavailable_guild() {
let expected = GuildCreate::Unavailable(UnavailableGuild {
id: Id::new(1234),
unavailable: true,
unavailable: MustBeBool,
});

// Note: serde(untagged) makes the enum transparent which is
Expand Down
5 changes: 3 additions & 2 deletions twilight-model/src/gateway/payload/incoming/ready.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mod tests {
id::Id,
oauth::{ApplicationFlags, PartialApplication},
user::CurrentUser,
util::mustbe::MustBeBool,
};
use serde_test::Token;

Expand All @@ -34,11 +35,11 @@ mod tests {
let guilds = vec![
UnavailableGuild {
id: Id::new(1),
unavailable: true,
unavailable: MustBeBool,
},
UnavailableGuild {
id: Id::new(2),
unavailable: true,
unavailable: MustBeBool,
},
];

Expand Down
11 changes: 7 additions & 4 deletions twilight-model/src/guild/unavailable_guild.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
use crate::id::{marker::GuildMarker, Id};
use crate::{
id::{marker::GuildMarker, Id},
util::mustbe::MustBeBool,
};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct UnavailableGuild {
pub id: Id<GuildMarker>,
pub unavailable: bool,
pub unavailable: MustBeBool<true>,
}

#[cfg(test)]
mod tests {
use super::UnavailableGuild;
use crate::id::Id;
use crate::{id::Id, util::mustbe::MustBeBool};
use serde_test::Token;

#[test]
fn unavailable_guild() {
let value = UnavailableGuild {
id: Id::new(1),
unavailable: true,
unavailable: MustBeBool,
};

serde_test::assert_tokens(
Expand Down
1 change: 1 addition & 0 deletions twilight-model/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pub mod datetime;
pub mod hex_color;
pub mod image_hash;
pub mod mustbe;

pub use self::{datetime::Timestamp, hex_color::HexColor, image_hash::ImageHash};

Expand Down
152 changes: 152 additions & 0 deletions twilight-model/src/util/mustbe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//! A struct that only deserializes from one specific boolean value.
//!
//! This module is heavily based upon
//! <https://github.com/dtolnay/monostate>.

use std::{
fmt::{self, Debug},
hash::Hash,
};

use serde::{
de::{Error, Unexpected, Visitor},
Deserialize, Serialize,
};

/// Struct that will only serialize from the bool specified as `T`.
#[derive(Clone, Copy, Default)]
pub struct MustBeBool<const T: bool>;

impl<const T: bool> MustBeBool<T> {
/// Get the expected boolean
pub const fn get(self) -> bool {
T
}
}

impl<const T: bool, const U: bool> PartialEq<MustBeBool<U>> for MustBeBool<T> {
fn eq(&self, _: &MustBeBool<U>) -> bool {
T.eq(&U)
}
}

impl<const T: bool> Eq for MustBeBool<T> {}

impl<const T: bool> Debug for MustBeBool<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("MustBeBool").field(&T).finish()
}
}

impl<const T: bool> Hash for MustBeBool<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
T.hash(state)
}
}

impl<'de, const T: bool> Deserialize<'de> for MustBeBool<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct MustBeBoolVisitor(bool);

impl<'de> Visitor<'de> for MustBeBoolVisitor {
type Value = ();

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "boolean `{}`", self.0)
}

fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
where
E: Error,
{
if v == self.0 {
Ok(())
} else {
Err(E::invalid_value(Unexpected::Bool(v), &self))
}
}
}

deserializer
.deserialize_any(MustBeBoolVisitor(T))
.map(|()| MustBeBool)
}
}

impl<const T: bool> Serialize for MustBeBool<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_bool(T)
}
}

#[cfg(test)]
mod tests {
use super::MustBeBool;

use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize)]
struct MTrue {
m: MustBeBool<true>,
}

#[derive(Deserialize, Serialize)]
struct MFalse {
m: MustBeBool<false>,
}

#[derive(Deserialize, Serialize)]
#[serde(untagged)]
enum TestEnum {
VariantTrue(MTrue),
VariantFalse(MFalse),
}

#[test]
#[allow(unused)]
fn true_false_enum() {
let json_1 = r#"{ "m": false }"#;
let result_1 = serde_json::from_str::<TestEnum>(json_1).unwrap();
assert!(matches!(result_1, TestEnum::VariantFalse(_)));

let json_2 = r#"{ "m": true }"#;
let result_2 = serde_json::from_str::<TestEnum>(json_2).unwrap();
assert!(matches!(result_2, TestEnum::VariantTrue(_)));
}

#[test]
fn default_value() {
#[derive(Deserialize, Serialize)]
struct MFalse {
#[serde(default)]
m: MustBeBool<false>,
}

let json_1 = r#"{}"#;
serde_json::from_str::<MFalse>(json_1).unwrap();
}

#[test]
fn ser() {
let val = TestEnum::VariantTrue(MTrue { m: MustBeBool });
let result = serde_json::to_string(&val).unwrap();
assert_eq!(r#"{"m":true}"#, result);

let val = TestEnum::VariantFalse(MFalse { m: MustBeBool });
let result = serde_json::to_string(&val).unwrap();
assert_eq!(r#"{"m":false}"#, result);
}

#[test]
fn equality() {
assert_ne!(MustBeBool::<false>, MustBeBool::<true>);
assert_eq!(MustBeBool::<false>, MustBeBool::<false>);
assert_eq!(MustBeBool::<true>, MustBeBool::<true>);
}
}

0 comments on commit aa3cb6e

Please sign in to comment.