Skip to content

Commit

Permalink
fix: Fix serialization of Unknown(DataModel) (#214)
Browse files Browse the repository at this point in the history
* Add tests for unknown types

* Fix serialization of DataModel
  • Loading branch information
sugyan authored Aug 14, 2024
1 parent 055d6b6 commit e5b7cb9
Showing 1 changed file with 48 additions and 21 deletions.
69 changes: 48 additions & 21 deletions atrium-api/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
use crate::error::Error;
use ipld_core::ipld::Ipld;
use ipld_core::serde::to_ipld;
use serde::{de, ser};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt;
use std::ops::{Deref, DerefMut};
Expand All @@ -25,7 +27,7 @@ pub trait Collection: fmt::Debug {
const NSID: &'static str;

/// This collection's record type.
type Record: fmt::Debug + serde::de::DeserializeOwned + serde::Serialize;
type Record: fmt::Debug + de::DeserializeOwned + Serialize;

/// Returns the [`Nsid`] for the Lexicon that defines the schema of records in this
/// collection.
Expand Down Expand Up @@ -60,30 +62,30 @@ pub trait Collection: fmt::Debug {
/// Definitions for Blob types.
/// Usually a map with `$type` is used, but deprecated legacy formats are also supported for parsing.
/// <https://atproto.com/specs/data-model#blob-type>
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum BlobRef {
Typed(TypedBlobRef),
Untyped(UnTypedBlobRef),
}

/// Current, typed blob reference.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "$type", rename_all = "lowercase")]
pub enum TypedBlobRef {
Blob(Blob),
}

/// An untyped blob reference.
/// Some records in the wild still contain this format, but should never write them.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct UnTypedBlobRef {
pub cid: String,
pub mime_type: String,
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
pub r#ref: CidLink,
Expand All @@ -92,7 +94,7 @@ pub struct Blob {
}

/// A generic object type.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct Object<T> {
#[serde(flatten)]
pub data: T,
Expand Down Expand Up @@ -124,7 +126,7 @@ impl<T> DerefMut for Object<T> {
}

/// An "open" union type.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum Union<T> {
Refs(T),
Expand All @@ -134,7 +136,7 @@ pub enum Union<T> {
/// Data with an unknown schema in an open [`Union`].
///
/// The data of variants represented by a map and include a `$type` field indicating the variant type.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct UnknownData {
#[serde(rename = "$type")]
pub r#type: String,
Expand All @@ -147,17 +149,28 @@ pub struct UnknownData {
/// Corresponds to [the `unknown` field type].
///
/// [the `unknown` field type]: https://atproto.com/specs/lexicon#unknown
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum Unknown {
Object(BTreeMap<String, DataModel>),
Null,
Other(DataModel),
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(try_from = "Ipld")]
pub struct DataModel(Ipld);
pub struct DataModel(#[serde(serialize_with = "serialize_data_model")] Ipld);

fn serialize_data_model<S>(ipld: &Ipld, serializer: S) -> Result<S::Ok, S::Error>
where
S: ser::Serializer,
{
match ipld {
Ipld::Float(_) => Err(serde::ser::Error::custom("float values are not allowed")),
Ipld::Link(link) => CidLink(*link).serialize(serializer),
_ => ipld.serialize(serializer),
}
}

impl Deref for DataModel {
type Target = Ipld;
Expand Down Expand Up @@ -209,7 +222,7 @@ pub trait TryFromUnknown: Sized {

impl<T> TryFromUnknown for T
where
T: serde::de::DeserializeOwned,
T: de::DeserializeOwned,
{
type Error = Error;

Expand Down Expand Up @@ -243,7 +256,7 @@ pub trait TryIntoUnknown {

impl<T> TryIntoUnknown for T
where
T: serde::Serialize,
T: Serialize,
{
type Error = Error;

Expand Down Expand Up @@ -384,7 +397,7 @@ mod tests {

#[test]
fn union() {
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "$type")]
enum FooRefs {
#[serde(rename = "example.com#bar")]
Expand All @@ -393,12 +406,12 @@ mod tests {
Baz(Box<Baz>),
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct Bar {
bar: String,
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct Baz {
baz: i32,
}
Expand Down Expand Up @@ -434,7 +447,7 @@ mod tests {

#[test]
fn unknown_serialize() {
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct Foo {
foo: Unknown,
}
Expand All @@ -451,7 +464,7 @@ mod tests {

#[test]
fn unknown_deserialize() {
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct Foo {
foo: Unknown,
}
Expand Down Expand Up @@ -544,7 +557,7 @@ mod tests {

#[test]
fn unknown_try_from() {
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "$type")]
enum Foo {
#[serde(rename = "example.com#bar")]
Expand All @@ -553,12 +566,12 @@ mod tests {
Baz(Box<Baz>),
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct Bar {
bar: String,
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct Baz {
baz: i32,
}
Expand Down Expand Up @@ -603,4 +616,18 @@ mod tests {
assert_eq!(barbaz, Foo::Baz(Box::new(Baz { baz: 42 })));
}
}

#[test]
fn serialize_unknown_from_cid_link() {
let cid_link =
CidLink::try_from("bafkreibme22gw2h7y2h7tg2fhqotaqjucnbc24deqo72b6mkl2egezxhvy")
.expect("failed to create cid-link");
let unknown = cid_link
.try_into_unknown()
.expect("failed to convert to unknown");
assert_eq!(
serde_json::to_string(&unknown).expect("failed to serialize unknown"),
r#"{"$link":"bafkreibme22gw2h7y2h7tg2fhqotaqjucnbc24deqo72b6mkl2egezxhvy"}"#
);
}
}

0 comments on commit e5b7cb9

Please sign in to comment.