From 4e90832a26df3772728ac7093ed04941e454117c Mon Sep 17 00:00:00 2001
From: Kyle Kotowick <kotowick@invictonlabs.com>
Date: Thu, 30 Jan 2025 14:04:43 -0500
Subject: [PATCH] Use macro for various Dilithium functions

---
 Cargo.toml |   2 +
 src/lib.rs | 632 ++++++++++++++---------------------------------------
 2 files changed, 165 insertions(+), 469 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 670fba7..4cd39ec 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,8 @@ heapless-bytes = "0.3.0"
 pqcrypto-dilithium = { version = "0.5.0", optional = true}
 serde-big-array = "0.5.1"
 serde_repr = "0.1"
+paste = "1.0"
+with_builtin_macros = "0.1.0"
 
 [dependencies.serde]
 version = "1.0"
diff --git a/src/lib.rs b/src/lib.rs
index 31a4baf..823bfda 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -44,9 +44,10 @@
        * label => values
    }
 */
-
+use ::with_builtin_macros::with_eager_expansions;
 use core::fmt::{self, Formatter};
 pub use heapless_bytes::Bytes;
+use paste::paste;
 #[cfg(feature = "backend-dilithium2")]
 use pqcrypto_dilithium::dilithium2;
 #[cfg(feature = "backend-dilithium3")]
@@ -198,26 +199,6 @@ impl From<TotpPublicKey> for PublicKey {
     }
 }
 
-#[cfg(feature = "backend-dilithium2")]
-impl From<Dilithium2PublicKey> for PublicKey {
-    fn from(key: Dilithium2PublicKey) -> Self {
-        PublicKey::Dilithium2(key)
-    }
-}
-
-#[cfg(feature = "backend-dilithium3")]
-impl From<Dilithium3PublicKey> for PublicKey {
-    fn from(key: Dilithium3PublicKey) -> Self {
-        PublicKey::Dilithium3(key)
-    }
-}
-#[cfg(feature = "backend-dilithium5")]
-impl From<Dilithium5PublicKey> for PublicKey {
-    fn from(key: Dilithium5PublicKey) -> Self {
-        PublicKey::Dilithium5(key)
-    }
-}
-
 #[derive(Clone, Debug, Default)]
 struct RawEcPublicKey {
     kty: Option<Kty>,
@@ -227,30 +208,6 @@ struct RawEcPublicKey {
     y: Option<Bytes<32>>,
 }
 
-#[derive(Clone, Debug, Default)]
-#[cfg(feature = "backend-dilithium2")]
-struct RawDilithium2PublicKey {
-    kty: Option<Kty>,
-    alg: Option<Alg>,
-    pk: Option<Bytes<{ dilithium2::public_key_bytes() }>>,
-}
-
-#[derive(Clone, Debug, Default)]
-#[cfg(feature = "backend-dilithium3")]
-struct RawDilithium3PublicKey {
-    kty: Option<Kty>,
-    alg: Option<Alg>,
-    pk: Option<Bytes<{ dilithium3::public_key_bytes() }>>,
-}
-
-#[derive(Clone, Debug, Default)]
-#[cfg(feature = "backend-dilithium5")]
-struct RawDilithium5PublicKey {
-    kty: Option<Kty>,
-    alg: Option<Alg>,
-    pk: Option<Bytes<{ dilithium5::public_key_bytes() }>>,
-}
-
 impl<'de> Deserialize<'de> for RawEcPublicKey {
     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
@@ -377,318 +334,6 @@ impl Serialize for RawEcPublicKey {
     }
 }
 
-#[cfg(feature = "backend-dilithium2")]
-impl<'de> Deserialize<'de> for RawDilithium2PublicKey {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: serde::Deserializer<'de>,
-    {
-        struct IndexedVisitor;
-        impl<'de> serde::de::Visitor<'de> for IndexedVisitor {
-            type Value = RawDilithium2PublicKey;
-
-            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
-                formatter.write_str("RawDilithium2PublicKey")
-            }
-
-            fn visit_map<V>(self, mut map: V) -> Result<RawDilithium2PublicKey, V::Error>
-            where
-                V: MapAccess<'de>,
-            {
-                #[derive(PartialEq)]
-                enum Key {
-                    Label(Label),
-                    Unknown(i8),
-                    None,
-                }
-
-                fn next_key<'a, V: MapAccess<'a>>(map: &mut V) -> Result<Key, V::Error> {
-                    let key: Option<i8> = map.next_key()?;
-                    let key = match key {
-                        Some(key) => match Label::try_from(key) {
-                            Ok(label) => Key::Label(label),
-                            Err(_) => Key::Unknown(key),
-                        },
-                        None => Key::None,
-                    };
-                    Ok(key)
-                }
-
-                let mut public_key = RawDilithium2PublicKey::default();
-
-                // As we cannot deserialize arbitrary values with cbor-smol, we do not support
-                // unknown keys before a known key.  If there are unknown keys, they must be at the
-                // end.
-
-                // only deserialize in canonical order
-
-                let mut key = next_key(&mut map)?;
-
-                if key == Key::Label(Label::Kty) {
-                    public_key.kty = Some(map.next_value()?);
-                    key = next_key(&mut map)?;
-                }
-
-                if key == Key::Label(Label::Alg) {
-                    public_key.alg = Some(map.next_value()?);
-                    key = next_key(&mut map)?;
-                }
-
-                if key == Key::Label(Label::CrvOrPk) {
-                    public_key.pk = Some(map.next_value()?);
-                    key = next_key(&mut map)?;
-                }
-
-                // if there is another key, it should be an unknown one
-                if matches!(key, Key::Label(_)) {
-                    Err(serde::de::Error::custom(
-                        "public key data in wrong order or with duplicates",
-                    ))
-                } else {
-                    Ok(public_key)
-                }
-            }
-        }
-        deserializer.deserialize_map(IndexedVisitor {})
-    }
-}
-
-#[cfg(feature = "backend-dilithium2")]
-impl Serialize for RawDilithium2PublicKey {
-    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        let is_set = [self.kty.is_some(), self.alg.is_some(), self.pk.is_some()];
-        let fields = is_set.into_iter().map(usize::from).sum();
-        use serde::ser::SerializeMap;
-        let mut map = serializer.serialize_map(Some(fields))?;
-
-        //  1: kty
-        if let Some(kty) = &self.kty {
-            map.serialize_entry(&(Label::Kty as i8), &(*kty as i8))?;
-        }
-        //  3: alg
-        if let Some(alg) = &self.alg {
-            map.serialize_entry(&(Label::Alg as i8), &(*alg as i8))?;
-        }
-        // -1: pk
-        if let Some(pk) = &self.pk {
-            map.serialize_entry(&(Label::CrvOrPk as i8), pk)?;
-        }
-
-        map.end()
-    }
-}
-
-#[cfg(feature = "backend-dilithium3")]
-impl<'de> Deserialize<'de> for RawDilithium3PublicKey {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: serde::Deserializer<'de>,
-    {
-        struct IndexedVisitor;
-        impl<'de> serde::de::Visitor<'de> for IndexedVisitor {
-            type Value = RawDilithium3PublicKey;
-
-            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
-                formatter.write_str("RawDilithium2PublicKey")
-            }
-
-            fn visit_map<V>(self, mut map: V) -> Result<RawDilithium3PublicKey, V::Error>
-            where
-                V: MapAccess<'de>,
-            {
-                #[derive(PartialEq)]
-                enum Key {
-                    Label(Label),
-                    Unknown(i8),
-                    None,
-                }
-
-                fn next_key<'a, V: MapAccess<'a>>(map: &mut V) -> Result<Key, V::Error> {
-                    let key: Option<i8> = map.next_key()?;
-                    let key = match key {
-                        Some(key) => match Label::try_from(key) {
-                            Ok(label) => Key::Label(label),
-                            Err(_) => Key::Unknown(key),
-                        },
-                        None => Key::None,
-                    };
-                    Ok(key)
-                }
-
-                let mut public_key = RawDilithium3PublicKey::default();
-
-                // As we cannot deserialize arbitrary values with cbor-smol, we do not support
-                // unknown keys before a known key.  If there are unknown keys, they must be at the
-                // end.
-
-                // only deserialize in canonical order
-
-                let mut key = next_key(&mut map)?;
-
-                if key == Key::Label(Label::Kty) {
-                    public_key.kty = Some(map.next_value()?);
-                    key = next_key(&mut map)?;
-                }
-
-                if key == Key::Label(Label::Alg) {
-                    public_key.alg = Some(map.next_value()?);
-                    key = next_key(&mut map)?;
-                }
-
-                if key == Key::Label(Label::CrvOrPk) {
-                    public_key.pk = Some(map.next_value()?);
-                    key = next_key(&mut map)?;
-                }
-
-                // if there is another key, it should be an unknown one
-                if matches!(key, Key::Label(_)) {
-                    Err(serde::de::Error::custom(
-                        "public key data in wrong order or with duplicates",
-                    ))
-                } else {
-                    Ok(public_key)
-                }
-            }
-        }
-        deserializer.deserialize_map(IndexedVisitor {})
-    }
-}
-
-#[cfg(feature = "backend-dilithium3")]
-impl Serialize for RawDilithium3PublicKey {
-    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        let is_set = [self.kty.is_some(), self.alg.is_some(), self.pk.is_some()];
-        let fields = is_set.into_iter().map(usize::from).sum();
-        use serde::ser::SerializeMap;
-        let mut map = serializer.serialize_map(Some(fields))?;
-
-        //  1: kty
-        if let Some(kty) = &self.kty {
-            map.serialize_entry(&(Label::Kty as i8), &(*kty as i8))?;
-        }
-        //  3: alg
-        if let Some(alg) = &self.alg {
-            map.serialize_entry(&(Label::Alg as i8), &(*alg as i8))?;
-        }
-        // -1: pk
-        if let Some(pk) = &self.pk {
-            map.serialize_entry(&(Label::CrvOrPk as i8), pk)?;
-        }
-
-        map.end()
-    }
-}
-
-#[cfg(feature = "backend-dilithium5")]
-impl<'de> Deserialize<'de> for RawDilithium5PublicKey {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: serde::Deserializer<'de>,
-    {
-        struct IndexedVisitor;
-        impl<'de> serde::de::Visitor<'de> for IndexedVisitor {
-            type Value = RawDilithium5PublicKey;
-
-            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
-                formatter.write_str("RawDilithium2PublicKey")
-            }
-
-            fn visit_map<V>(self, mut map: V) -> Result<RawDilithium5PublicKey, V::Error>
-            where
-                V: MapAccess<'de>,
-            {
-                #[derive(PartialEq)]
-                enum Key {
-                    Label(Label),
-                    Unknown(i8),
-                    None,
-                }
-
-                fn next_key<'a, V: MapAccess<'a>>(map: &mut V) -> Result<Key, V::Error> {
-                    let key: Option<i8> = map.next_key()?;
-                    let key = match key {
-                        Some(key) => match Label::try_from(key) {
-                            Ok(label) => Key::Label(label),
-                            Err(_) => Key::Unknown(key),
-                        },
-                        None => Key::None,
-                    };
-                    Ok(key)
-                }
-
-                let mut public_key = RawDilithium5PublicKey::default();
-
-                // As we cannot deserialize arbitrary values with cbor-smol, we do not support
-                // unknown keys before a known key.  If there are unknown keys, they must be at the
-                // end.
-
-                // only deserialize in canonical order
-
-                let mut key = next_key(&mut map)?;
-
-                if key == Key::Label(Label::Kty) {
-                    public_key.kty = Some(map.next_value()?);
-                    key = next_key(&mut map)?;
-                }
-
-                if key == Key::Label(Label::Alg) {
-                    public_key.alg = Some(map.next_value()?);
-                    key = next_key(&mut map)?;
-                }
-
-                if key == Key::Label(Label::CrvOrPk) {
-                    public_key.pk = Some(map.next_value()?);
-                    key = next_key(&mut map)?;
-                }
-
-                // if there is another key, it should be an unknown one
-                if matches!(key, Key::Label(_)) {
-                    Err(serde::de::Error::custom(
-                        "public key data in wrong order or with duplicates",
-                    ))
-                } else {
-                    Ok(public_key)
-                }
-            }
-        }
-        deserializer.deserialize_map(IndexedVisitor {})
-    }
-}
-
-#[cfg(feature = "backend-dilithium5")]
-impl Serialize for RawDilithium5PublicKey {
-    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        let is_set = [self.kty.is_some(), self.alg.is_some(), self.pk.is_some()];
-        let fields = is_set.into_iter().map(usize::from).sum();
-        use serde::ser::SerializeMap;
-        let mut map = serializer.serialize_map(Some(fields))?;
-
-        //  1: kty
-        if let Some(kty) = &self.kty {
-            map.serialize_entry(&(Label::Kty as i8), &(*kty as i8))?;
-        }
-        //  3: alg
-        if let Some(alg) = &self.alg {
-            map.serialize_entry(&(Label::Alg as i8), &(*alg as i8))?;
-        }
-        // -1: pk
-        if let Some(pk) = &self.pk {
-            map.serialize_entry(&(Label::CrvOrPk as i8), pk)?;
-        }
-
-        map.end()
-    }
-}
-
 trait PublicKeyConstants {
     const KTY: Kty;
     const ALG: Alg;
@@ -769,81 +414,6 @@ impl From<Ed25519PublicKey> for RawEcPublicKey {
     }
 }
 
-#[cfg(feature = "backend-dilithium2")]
-#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
-#[serde(into = "RawDilithium2PublicKey")]
-pub struct Dilithium2PublicKey {
-    pub pk: Bytes<{ dilithium2::public_key_bytes() }>,
-}
-
-#[cfg(feature = "backend-dilithium2")]
-impl PublicKeyConstants for Dilithium2PublicKey {
-    const KTY: Kty = Kty::Pqc;
-    const ALG: Alg = Alg::Dilithium2;
-    const CRV: Crv = Crv::None;
-}
-
-#[cfg(feature = "backend-dilithium2")]
-impl From<Dilithium2PublicKey> for RawDilithium2PublicKey {
-    fn from(key: Dilithium2PublicKey) -> Self {
-        Self {
-            kty: Some(Dilithium2PublicKey::KTY),
-            alg: Some(Dilithium2PublicKey::ALG),
-            pk: Some(key.pk),
-        }
-    }
-}
-
-#[cfg(feature = "backend-dilithium3")]
-#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
-#[serde(into = "RawDilithium3PublicKey")]
-pub struct Dilithium3PublicKey {
-    pub pk: Bytes<{ dilithium3::public_key_bytes() }>,
-}
-
-#[cfg(feature = "backend-dilithium3")]
-impl PublicKeyConstants for Dilithium3PublicKey {
-    const KTY: Kty = Kty::Pqc;
-    const ALG: Alg = Alg::Dilithium3;
-    const CRV: Crv = Crv::None;
-}
-
-#[cfg(feature = "backend-dilithium3")]
-impl From<Dilithium3PublicKey> for RawDilithium3PublicKey {
-    fn from(key: Dilithium3PublicKey) -> Self {
-        Self {
-            kty: Some(Dilithium3PublicKey::KTY),
-            alg: Some(Dilithium3PublicKey::ALG),
-            pk: Some(key.pk),
-        }
-    }
-}
-
-#[cfg(feature = "backend-dilithium5")]
-#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
-#[serde(into = "RawDilithium5PublicKey")]
-pub struct Dilithium5PublicKey {
-    pub pk: Bytes<{ dilithium5::public_key_bytes() }>,
-}
-
-#[cfg(feature = "backend-dilithium5")]
-impl PublicKeyConstants for Dilithium5PublicKey {
-    const KTY: Kty = Kty::Pqc;
-    const ALG: Alg = Alg::Dilithium5;
-    const CRV: Crv = Crv::None;
-}
-
-#[cfg(feature = "backend-dilithium5")]
-impl From<Dilithium5PublicKey> for RawDilithium5PublicKey {
-    fn from(key: Dilithium5PublicKey) -> Self {
-        Self {
-            kty: Some(Dilithium5PublicKey::KTY),
-            alg: Some(Dilithium5PublicKey::ALG),
-            pk: Some(key.pk),
-        }
-    }
-}
-
 #[derive(Clone, Debug, Default, Eq, PartialEq, Serialize)]
 #[serde(into = "RawEcPublicKey")]
 pub struct TotpPublicKey {}
@@ -946,44 +516,168 @@ impl<'de> serde::Deserialize<'de> for Ed25519PublicKey {
     }
 }
 
-#[cfg(feature = "backend-dilithium2")]
-impl<'de> serde::Deserialize<'de> for Dilithium2PublicKey {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: serde::Deserializer<'de>,
-    {
-        let RawDilithium2PublicKey { kty, alg, pk, .. } =
-            RawDilithium2PublicKey::deserialize(deserializer)?;
-        check_key_constants::<Dilithium2PublicKey, D::Error>(kty, alg, Some(Crv::None))?;
-        let pk = pk.ok_or_else(|| D::Error::missing_field("pk"))?;
-        Ok(Self { pk })
-    }
-}
+#[cfg(feature = "backend-dilithium")]
+macro_rules! dilithium_public_key {
+    ($dilithium_number: tt) => {
+        paste! {
+            with_eager_expansions! {
+                #[derive(Clone, Debug, Eq, PartialEq, Serialize)]
+                #[serde(into = #{ concat!("RawDilithium", stringify!($dilithium_number), "PublicKey") })]
+                pub struct [<Dilithium $dilithium_number PublicKey>] {
+                    pub pk: Bytes<{ [<dilithium $dilithium_number>]::public_key_bytes() }>,
+                }
 
-#[cfg(feature = "backend-dilithium3")]
-impl<'de> serde::Deserialize<'de> for Dilithium3PublicKey {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: serde::Deserializer<'de>,
-    {
-        let RawDilithium3PublicKey { kty, alg, pk, .. } =
-            RawDilithium3PublicKey::deserialize(deserializer)?;
-        check_key_constants::<Dilithium3PublicKey, D::Error>(kty, alg, Some(Crv::None))?;
-        let pk = pk.ok_or_else(|| D::Error::missing_field("pk"))?;
-        Ok(Self { pk })
-    }
+                impl PublicKeyConstants for [<Dilithium $dilithium_number PublicKey>] {
+                    const KTY: Kty = Kty::Pqc;
+                    const ALG: Alg = Alg::[<Dilithium $dilithium_number>];
+                    const CRV: Crv = Crv::None;
+                }
+
+                impl From<[<Dilithium $dilithium_number PublicKey>]> for PublicKey {
+                    fn from(key: [<Dilithium $dilithium_number PublicKey>]) -> Self {
+                        PublicKey::[<Dilithium $dilithium_number>](key)
+                    }
+                }
+
+                #[derive(Clone, Debug, Default)]
+                struct [<RawDilithium $dilithium_number PublicKey>] {
+                    kty: Option<Kty>,
+                    alg: Option<Alg>,
+                    pk: Option<Bytes<{ [<dilithium $dilithium_number>]::public_key_bytes() }>>,
+                }
+
+                impl From<[<Dilithium $dilithium_number PublicKey>]> for [<RawDilithium $dilithium_number PublicKey>] {
+                    fn from(key: [<Dilithium $dilithium_number PublicKey>]) -> Self {
+                        Self {
+                            kty: Some([<Dilithium $dilithium_number PublicKey>]::KTY),
+                            alg: Some([<Dilithium $dilithium_number PublicKey>]::ALG),
+                            pk: Some(key.pk),
+                        }
+                    }
+                }
+
+                impl<'de> serde::Deserialize<'de> for [<Dilithium $dilithium_number PublicKey>] {
+                    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+                    where
+                        D: serde::Deserializer<'de>,
+                    {
+                        let [<RawDilithium $dilithium_number PublicKey>] { kty, alg, pk, .. } =
+                        [<RawDilithium $dilithium_number PublicKey>]::deserialize(deserializer)?;
+                        check_key_constants::<[<Dilithium $dilithium_number PublicKey>], D::Error>(kty, alg, Some(Crv::None))?;
+                        let pk = pk.ok_or_else(|| D::Error::missing_field("pk"))?;
+                        Ok(Self { pk })
+                    }
+                }
+
+                impl<'de> Deserialize<'de> for [<RawDilithium $dilithium_number PublicKey>] {
+                    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+                    where
+                        D: serde::Deserializer<'de>,
+                    {
+                        struct IndexedVisitor;
+                        impl<'de> serde::de::Visitor<'de> for IndexedVisitor {
+                            type Value = [<RawDilithium $dilithium_number PublicKey>];
+
+                            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
+                                formatter.write_str(concat!("RawDilithium", stringify!($dilithium_number), "PublicKey"))
+                            }
+
+                            fn visit_map<V>(self, mut map: V) -> Result<[<RawDilithium $dilithium_number PublicKey>], V::Error>
+                            where
+                                V: MapAccess<'de>,
+                            {
+                                #[derive(PartialEq)]
+                                enum Key {
+                                    Label(Label),
+                                    Unknown(i8),
+                                    None,
+                                }
+
+                                fn next_key<'a, V: MapAccess<'a>>(map: &mut V) -> Result<Key, V::Error> {
+                                    let key: Option<i8> = map.next_key()?;
+                                    let key = match key {
+                                        Some(key) => match Label::try_from(key) {
+                                            Ok(label) => Key::Label(label),
+                                            Err(_) => Key::Unknown(key),
+                                        },
+                                        None => Key::None,
+                                    };
+                                    Ok(key)
+                                }
+
+                                let mut public_key = [<RawDilithium $dilithium_number PublicKey>]::default();
+
+                                // As we cannot deserialize arbitrary values with cbor-smol, we do not support
+                                // unknown keys before a known key.  If there are unknown keys, they must be at the
+                                // end.
+
+                                // only deserialize in canonical order
+
+                                let mut key = next_key(&mut map)?;
+
+                                if key == Key::Label(Label::Kty) {
+                                    public_key.kty = Some(map.next_value()?);
+                                    key = next_key(&mut map)?;
+                                }
+
+                                if key == Key::Label(Label::Alg) {
+                                    public_key.alg = Some(map.next_value()?);
+                                    key = next_key(&mut map)?;
+                                }
+
+                                if key == Key::Label(Label::CrvOrPk) {
+                                    public_key.pk = Some(map.next_value()?);
+                                    key = next_key(&mut map)?;
+                                }
+
+                                // if there is another key, it should be an unknown one
+                                if matches!(key, Key::Label(_)) {
+                                    Err(serde::de::Error::custom(
+                                        "public key data in wrong order or with duplicates",
+                                    ))
+                                } else {
+                                    Ok(public_key)
+                                }
+                            }
+                        }
+                        deserializer.deserialize_map(IndexedVisitor {})
+                    }
+                }
+
+                impl Serialize for [<RawDilithium $dilithium_number PublicKey>] {
+                    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
+                    where
+                        S: serde::Serializer,
+                    {
+                        let is_set = [self.kty.is_some(), self.alg.is_some(), self.pk.is_some()];
+                        let fields = is_set.into_iter().map(usize::from).sum();
+                        use serde::ser::SerializeMap;
+                        let mut map = serializer.serialize_map(Some(fields))?;
+
+                        //  1: kty
+                        if let Some(kty) = &self.kty {
+                            map.serialize_entry(&(Label::Kty as i8), &(*kty as i8))?;
+                        }
+                        //  3: alg
+                        if let Some(alg) = &self.alg {
+                            map.serialize_entry(&(Label::Alg as i8), &(*alg as i8))?;
+                        }
+                        // -1: pk
+                        if let Some(pk) = &self.pk {
+                            map.serialize_entry(&(Label::CrvOrPk as i8), pk)?;
+                        }
+
+                        map.end()
+                    }
+                }
+            }
+        }
+    };
 }
 
+#[cfg(feature = "backend-dilithium2")]
+dilithium_public_key!(2);
+#[cfg(feature = "backend-dilithium3")]
+dilithium_public_key!(3);
 #[cfg(feature = "backend-dilithium5")]
-impl<'de> serde::Deserialize<'de> for Dilithium5PublicKey {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: serde::Deserializer<'de>,
-    {
-        let RawDilithium5PublicKey { kty, alg, pk, .. } =
-            RawDilithium5PublicKey::deserialize(deserializer)?;
-        check_key_constants::<Dilithium5PublicKey, D::Error>(kty, alg, Some(Crv::None))?;
-        let pk = pk.ok_or_else(|| D::Error::missing_field("pk"))?;
-        Ok(Self { pk })
-    }
-}
+dilithium_public_key!(5);