From 4e90832a26df3772728ac7093ed04941e454117c Mon Sep 17 00:00:00 2001 From: Kyle Kotowick <kotowick@invictonlabs.com> Date: Thu, 30 Jan 2025 14:04:43 -0500 Subject: [PATCH] Use macro for various Dilithium functions --- Cargo.toml | 2 + src/lib.rs | 632 ++++++++++++++--------------------------------------- 2 files changed, 165 insertions(+), 469 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 670fba7..4cd39ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,8 @@ heapless-bytes = "0.3.0" pqcrypto-dilithium = { version = "0.5.0", optional = true} serde-big-array = "0.5.1" serde_repr = "0.1" +paste = "1.0" +with_builtin_macros = "0.1.0" [dependencies.serde] version = "1.0" diff --git a/src/lib.rs b/src/lib.rs index 31a4baf..823bfda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,9 +44,10 @@ * label => values } */ - +use ::with_builtin_macros::with_eager_expansions; use core::fmt::{self, Formatter}; pub use heapless_bytes::Bytes; +use paste::paste; #[cfg(feature = "backend-dilithium2")] use pqcrypto_dilithium::dilithium2; #[cfg(feature = "backend-dilithium3")] @@ -198,26 +199,6 @@ impl From<TotpPublicKey> for PublicKey { } } -#[cfg(feature = "backend-dilithium2")] -impl From<Dilithium2PublicKey> for PublicKey { - fn from(key: Dilithium2PublicKey) -> Self { - PublicKey::Dilithium2(key) - } -} - -#[cfg(feature = "backend-dilithium3")] -impl From<Dilithium3PublicKey> for PublicKey { - fn from(key: Dilithium3PublicKey) -> Self { - PublicKey::Dilithium3(key) - } -} -#[cfg(feature = "backend-dilithium5")] -impl From<Dilithium5PublicKey> for PublicKey { - fn from(key: Dilithium5PublicKey) -> Self { - PublicKey::Dilithium5(key) - } -} - #[derive(Clone, Debug, Default)] struct RawEcPublicKey { kty: Option<Kty>, @@ -227,30 +208,6 @@ struct RawEcPublicKey { y: Option<Bytes<32>>, } -#[derive(Clone, Debug, Default)] -#[cfg(feature = "backend-dilithium2")] -struct RawDilithium2PublicKey { - kty: Option<Kty>, - alg: Option<Alg>, - pk: Option<Bytes<{ dilithium2::public_key_bytes() }>>, -} - -#[derive(Clone, Debug, Default)] -#[cfg(feature = "backend-dilithium3")] -struct RawDilithium3PublicKey { - kty: Option<Kty>, - alg: Option<Alg>, - pk: Option<Bytes<{ dilithium3::public_key_bytes() }>>, -} - -#[derive(Clone, Debug, Default)] -#[cfg(feature = "backend-dilithium5")] -struct RawDilithium5PublicKey { - kty: Option<Kty>, - alg: Option<Alg>, - pk: Option<Bytes<{ dilithium5::public_key_bytes() }>>, -} - impl<'de> Deserialize<'de> for RawEcPublicKey { fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where @@ -377,318 +334,6 @@ impl Serialize for RawEcPublicKey { } } -#[cfg(feature = "backend-dilithium2")] -impl<'de> Deserialize<'de> for RawDilithium2PublicKey { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - struct IndexedVisitor; - impl<'de> serde::de::Visitor<'de> for IndexedVisitor { - type Value = RawDilithium2PublicKey; - - fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { - formatter.write_str("RawDilithium2PublicKey") - } - - fn visit_map<V>(self, mut map: V) -> Result<RawDilithium2PublicKey, V::Error> - where - V: MapAccess<'de>, - { - #[derive(PartialEq)] - enum Key { - Label(Label), - Unknown(i8), - None, - } - - fn next_key<'a, V: MapAccess<'a>>(map: &mut V) -> Result<Key, V::Error> { - let key: Option<i8> = map.next_key()?; - let key = match key { - Some(key) => match Label::try_from(key) { - Ok(label) => Key::Label(label), - Err(_) => Key::Unknown(key), - }, - None => Key::None, - }; - Ok(key) - } - - let mut public_key = RawDilithium2PublicKey::default(); - - // As we cannot deserialize arbitrary values with cbor-smol, we do not support - // unknown keys before a known key. If there are unknown keys, they must be at the - // end. - - // only deserialize in canonical order - - let mut key = next_key(&mut map)?; - - if key == Key::Label(Label::Kty) { - public_key.kty = Some(map.next_value()?); - key = next_key(&mut map)?; - } - - if key == Key::Label(Label::Alg) { - public_key.alg = Some(map.next_value()?); - key = next_key(&mut map)?; - } - - if key == Key::Label(Label::CrvOrPk) { - public_key.pk = Some(map.next_value()?); - key = next_key(&mut map)?; - } - - // if there is another key, it should be an unknown one - if matches!(key, Key::Label(_)) { - Err(serde::de::Error::custom( - "public key data in wrong order or with duplicates", - )) - } else { - Ok(public_key) - } - } - } - deserializer.deserialize_map(IndexedVisitor {}) - } -} - -#[cfg(feature = "backend-dilithium2")] -impl Serialize for RawDilithium2PublicKey { - fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error> - where - S: serde::Serializer, - { - let is_set = [self.kty.is_some(), self.alg.is_some(), self.pk.is_some()]; - let fields = is_set.into_iter().map(usize::from).sum(); - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(fields))?; - - // 1: kty - if let Some(kty) = &self.kty { - map.serialize_entry(&(Label::Kty as i8), &(*kty as i8))?; - } - // 3: alg - if let Some(alg) = &self.alg { - map.serialize_entry(&(Label::Alg as i8), &(*alg as i8))?; - } - // -1: pk - if let Some(pk) = &self.pk { - map.serialize_entry(&(Label::CrvOrPk as i8), pk)?; - } - - map.end() - } -} - -#[cfg(feature = "backend-dilithium3")] -impl<'de> Deserialize<'de> for RawDilithium3PublicKey { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - struct IndexedVisitor; - impl<'de> serde::de::Visitor<'de> for IndexedVisitor { - type Value = RawDilithium3PublicKey; - - fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { - formatter.write_str("RawDilithium2PublicKey") - } - - fn visit_map<V>(self, mut map: V) -> Result<RawDilithium3PublicKey, V::Error> - where - V: MapAccess<'de>, - { - #[derive(PartialEq)] - enum Key { - Label(Label), - Unknown(i8), - None, - } - - fn next_key<'a, V: MapAccess<'a>>(map: &mut V) -> Result<Key, V::Error> { - let key: Option<i8> = map.next_key()?; - let key = match key { - Some(key) => match Label::try_from(key) { - Ok(label) => Key::Label(label), - Err(_) => Key::Unknown(key), - }, - None => Key::None, - }; - Ok(key) - } - - let mut public_key = RawDilithium3PublicKey::default(); - - // As we cannot deserialize arbitrary values with cbor-smol, we do not support - // unknown keys before a known key. If there are unknown keys, they must be at the - // end. - - // only deserialize in canonical order - - let mut key = next_key(&mut map)?; - - if key == Key::Label(Label::Kty) { - public_key.kty = Some(map.next_value()?); - key = next_key(&mut map)?; - } - - if key == Key::Label(Label::Alg) { - public_key.alg = Some(map.next_value()?); - key = next_key(&mut map)?; - } - - if key == Key::Label(Label::CrvOrPk) { - public_key.pk = Some(map.next_value()?); - key = next_key(&mut map)?; - } - - // if there is another key, it should be an unknown one - if matches!(key, Key::Label(_)) { - Err(serde::de::Error::custom( - "public key data in wrong order or with duplicates", - )) - } else { - Ok(public_key) - } - } - } - deserializer.deserialize_map(IndexedVisitor {}) - } -} - -#[cfg(feature = "backend-dilithium3")] -impl Serialize for RawDilithium3PublicKey { - fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error> - where - S: serde::Serializer, - { - let is_set = [self.kty.is_some(), self.alg.is_some(), self.pk.is_some()]; - let fields = is_set.into_iter().map(usize::from).sum(); - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(fields))?; - - // 1: kty - if let Some(kty) = &self.kty { - map.serialize_entry(&(Label::Kty as i8), &(*kty as i8))?; - } - // 3: alg - if let Some(alg) = &self.alg { - map.serialize_entry(&(Label::Alg as i8), &(*alg as i8))?; - } - // -1: pk - if let Some(pk) = &self.pk { - map.serialize_entry(&(Label::CrvOrPk as i8), pk)?; - } - - map.end() - } -} - -#[cfg(feature = "backend-dilithium5")] -impl<'de> Deserialize<'de> for RawDilithium5PublicKey { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - struct IndexedVisitor; - impl<'de> serde::de::Visitor<'de> for IndexedVisitor { - type Value = RawDilithium5PublicKey; - - fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { - formatter.write_str("RawDilithium2PublicKey") - } - - fn visit_map<V>(self, mut map: V) -> Result<RawDilithium5PublicKey, V::Error> - where - V: MapAccess<'de>, - { - #[derive(PartialEq)] - enum Key { - Label(Label), - Unknown(i8), - None, - } - - fn next_key<'a, V: MapAccess<'a>>(map: &mut V) -> Result<Key, V::Error> { - let key: Option<i8> = map.next_key()?; - let key = match key { - Some(key) => match Label::try_from(key) { - Ok(label) => Key::Label(label), - Err(_) => Key::Unknown(key), - }, - None => Key::None, - }; - Ok(key) - } - - let mut public_key = RawDilithium5PublicKey::default(); - - // As we cannot deserialize arbitrary values with cbor-smol, we do not support - // unknown keys before a known key. If there are unknown keys, they must be at the - // end. - - // only deserialize in canonical order - - let mut key = next_key(&mut map)?; - - if key == Key::Label(Label::Kty) { - public_key.kty = Some(map.next_value()?); - key = next_key(&mut map)?; - } - - if key == Key::Label(Label::Alg) { - public_key.alg = Some(map.next_value()?); - key = next_key(&mut map)?; - } - - if key == Key::Label(Label::CrvOrPk) { - public_key.pk = Some(map.next_value()?); - key = next_key(&mut map)?; - } - - // if there is another key, it should be an unknown one - if matches!(key, Key::Label(_)) { - Err(serde::de::Error::custom( - "public key data in wrong order or with duplicates", - )) - } else { - Ok(public_key) - } - } - } - deserializer.deserialize_map(IndexedVisitor {}) - } -} - -#[cfg(feature = "backend-dilithium5")] -impl Serialize for RawDilithium5PublicKey { - fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error> - where - S: serde::Serializer, - { - let is_set = [self.kty.is_some(), self.alg.is_some(), self.pk.is_some()]; - let fields = is_set.into_iter().map(usize::from).sum(); - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(fields))?; - - // 1: kty - if let Some(kty) = &self.kty { - map.serialize_entry(&(Label::Kty as i8), &(*kty as i8))?; - } - // 3: alg - if let Some(alg) = &self.alg { - map.serialize_entry(&(Label::Alg as i8), &(*alg as i8))?; - } - // -1: pk - if let Some(pk) = &self.pk { - map.serialize_entry(&(Label::CrvOrPk as i8), pk)?; - } - - map.end() - } -} - trait PublicKeyConstants { const KTY: Kty; const ALG: Alg; @@ -769,81 +414,6 @@ impl From<Ed25519PublicKey> for RawEcPublicKey { } } -#[cfg(feature = "backend-dilithium2")] -#[derive(Clone, Debug, Eq, PartialEq, Serialize)] -#[serde(into = "RawDilithium2PublicKey")] -pub struct Dilithium2PublicKey { - pub pk: Bytes<{ dilithium2::public_key_bytes() }>, -} - -#[cfg(feature = "backend-dilithium2")] -impl PublicKeyConstants for Dilithium2PublicKey { - const KTY: Kty = Kty::Pqc; - const ALG: Alg = Alg::Dilithium2; - const CRV: Crv = Crv::None; -} - -#[cfg(feature = "backend-dilithium2")] -impl From<Dilithium2PublicKey> for RawDilithium2PublicKey { - fn from(key: Dilithium2PublicKey) -> Self { - Self { - kty: Some(Dilithium2PublicKey::KTY), - alg: Some(Dilithium2PublicKey::ALG), - pk: Some(key.pk), - } - } -} - -#[cfg(feature = "backend-dilithium3")] -#[derive(Clone, Debug, Eq, PartialEq, Serialize)] -#[serde(into = "RawDilithium3PublicKey")] -pub struct Dilithium3PublicKey { - pub pk: Bytes<{ dilithium3::public_key_bytes() }>, -} - -#[cfg(feature = "backend-dilithium3")] -impl PublicKeyConstants for Dilithium3PublicKey { - const KTY: Kty = Kty::Pqc; - const ALG: Alg = Alg::Dilithium3; - const CRV: Crv = Crv::None; -} - -#[cfg(feature = "backend-dilithium3")] -impl From<Dilithium3PublicKey> for RawDilithium3PublicKey { - fn from(key: Dilithium3PublicKey) -> Self { - Self { - kty: Some(Dilithium3PublicKey::KTY), - alg: Some(Dilithium3PublicKey::ALG), - pk: Some(key.pk), - } - } -} - -#[cfg(feature = "backend-dilithium5")] -#[derive(Clone, Debug, Eq, PartialEq, Serialize)] -#[serde(into = "RawDilithium5PublicKey")] -pub struct Dilithium5PublicKey { - pub pk: Bytes<{ dilithium5::public_key_bytes() }>, -} - -#[cfg(feature = "backend-dilithium5")] -impl PublicKeyConstants for Dilithium5PublicKey { - const KTY: Kty = Kty::Pqc; - const ALG: Alg = Alg::Dilithium5; - const CRV: Crv = Crv::None; -} - -#[cfg(feature = "backend-dilithium5")] -impl From<Dilithium5PublicKey> for RawDilithium5PublicKey { - fn from(key: Dilithium5PublicKey) -> Self { - Self { - kty: Some(Dilithium5PublicKey::KTY), - alg: Some(Dilithium5PublicKey::ALG), - pk: Some(key.pk), - } - } -} - #[derive(Clone, Debug, Default, Eq, PartialEq, Serialize)] #[serde(into = "RawEcPublicKey")] pub struct TotpPublicKey {} @@ -946,44 +516,168 @@ impl<'de> serde::Deserialize<'de> for Ed25519PublicKey { } } -#[cfg(feature = "backend-dilithium2")] -impl<'de> serde::Deserialize<'de> for Dilithium2PublicKey { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - let RawDilithium2PublicKey { kty, alg, pk, .. } = - RawDilithium2PublicKey::deserialize(deserializer)?; - check_key_constants::<Dilithium2PublicKey, D::Error>(kty, alg, Some(Crv::None))?; - let pk = pk.ok_or_else(|| D::Error::missing_field("pk"))?; - Ok(Self { pk }) - } -} +#[cfg(feature = "backend-dilithium")] +macro_rules! dilithium_public_key { + ($dilithium_number: tt) => { + paste! { + with_eager_expansions! { + #[derive(Clone, Debug, Eq, PartialEq, Serialize)] + #[serde(into = #{ concat!("RawDilithium", stringify!($dilithium_number), "PublicKey") })] + pub struct [<Dilithium $dilithium_number PublicKey>] { + pub pk: Bytes<{ [<dilithium $dilithium_number>]::public_key_bytes() }>, + } -#[cfg(feature = "backend-dilithium3")] -impl<'de> serde::Deserialize<'de> for Dilithium3PublicKey { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - let RawDilithium3PublicKey { kty, alg, pk, .. } = - RawDilithium3PublicKey::deserialize(deserializer)?; - check_key_constants::<Dilithium3PublicKey, D::Error>(kty, alg, Some(Crv::None))?; - let pk = pk.ok_or_else(|| D::Error::missing_field("pk"))?; - Ok(Self { pk }) - } + impl PublicKeyConstants for [<Dilithium $dilithium_number PublicKey>] { + const KTY: Kty = Kty::Pqc; + const ALG: Alg = Alg::[<Dilithium $dilithium_number>]; + const CRV: Crv = Crv::None; + } + + impl From<[<Dilithium $dilithium_number PublicKey>]> for PublicKey { + fn from(key: [<Dilithium $dilithium_number PublicKey>]) -> Self { + PublicKey::[<Dilithium $dilithium_number>](key) + } + } + + #[derive(Clone, Debug, Default)] + struct [<RawDilithium $dilithium_number PublicKey>] { + kty: Option<Kty>, + alg: Option<Alg>, + pk: Option<Bytes<{ [<dilithium $dilithium_number>]::public_key_bytes() }>>, + } + + impl From<[<Dilithium $dilithium_number PublicKey>]> for [<RawDilithium $dilithium_number PublicKey>] { + fn from(key: [<Dilithium $dilithium_number PublicKey>]) -> Self { + Self { + kty: Some([<Dilithium $dilithium_number PublicKey>]::KTY), + alg: Some([<Dilithium $dilithium_number PublicKey>]::ALG), + pk: Some(key.pk), + } + } + } + + impl<'de> serde::Deserialize<'de> for [<Dilithium $dilithium_number PublicKey>] { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let [<RawDilithium $dilithium_number PublicKey>] { kty, alg, pk, .. } = + [<RawDilithium $dilithium_number PublicKey>]::deserialize(deserializer)?; + check_key_constants::<[<Dilithium $dilithium_number PublicKey>], D::Error>(kty, alg, Some(Crv::None))?; + let pk = pk.ok_or_else(|| D::Error::missing_field("pk"))?; + Ok(Self { pk }) + } + } + + impl<'de> Deserialize<'de> for [<RawDilithium $dilithium_number PublicKey>] { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + struct IndexedVisitor; + impl<'de> serde::de::Visitor<'de> for IndexedVisitor { + type Value = [<RawDilithium $dilithium_number PublicKey>]; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + formatter.write_str(concat!("RawDilithium", stringify!($dilithium_number), "PublicKey")) + } + + fn visit_map<V>(self, mut map: V) -> Result<[<RawDilithium $dilithium_number PublicKey>], V::Error> + where + V: MapAccess<'de>, + { + #[derive(PartialEq)] + enum Key { + Label(Label), + Unknown(i8), + None, + } + + fn next_key<'a, V: MapAccess<'a>>(map: &mut V) -> Result<Key, V::Error> { + let key: Option<i8> = map.next_key()?; + let key = match key { + Some(key) => match Label::try_from(key) { + Ok(label) => Key::Label(label), + Err(_) => Key::Unknown(key), + }, + None => Key::None, + }; + Ok(key) + } + + let mut public_key = [<RawDilithium $dilithium_number PublicKey>]::default(); + + // As we cannot deserialize arbitrary values with cbor-smol, we do not support + // unknown keys before a known key. If there are unknown keys, they must be at the + // end. + + // only deserialize in canonical order + + let mut key = next_key(&mut map)?; + + if key == Key::Label(Label::Kty) { + public_key.kty = Some(map.next_value()?); + key = next_key(&mut map)?; + } + + if key == Key::Label(Label::Alg) { + public_key.alg = Some(map.next_value()?); + key = next_key(&mut map)?; + } + + if key == Key::Label(Label::CrvOrPk) { + public_key.pk = Some(map.next_value()?); + key = next_key(&mut map)?; + } + + // if there is another key, it should be an unknown one + if matches!(key, Key::Label(_)) { + Err(serde::de::Error::custom( + "public key data in wrong order or with duplicates", + )) + } else { + Ok(public_key) + } + } + } + deserializer.deserialize_map(IndexedVisitor {}) + } + } + + impl Serialize for [<RawDilithium $dilithium_number PublicKey>] { + fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + let is_set = [self.kty.is_some(), self.alg.is_some(), self.pk.is_some()]; + let fields = is_set.into_iter().map(usize::from).sum(); + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(fields))?; + + // 1: kty + if let Some(kty) = &self.kty { + map.serialize_entry(&(Label::Kty as i8), &(*kty as i8))?; + } + // 3: alg + if let Some(alg) = &self.alg { + map.serialize_entry(&(Label::Alg as i8), &(*alg as i8))?; + } + // -1: pk + if let Some(pk) = &self.pk { + map.serialize_entry(&(Label::CrvOrPk as i8), pk)?; + } + + map.end() + } + } + } + } + }; } +#[cfg(feature = "backend-dilithium2")] +dilithium_public_key!(2); +#[cfg(feature = "backend-dilithium3")] +dilithium_public_key!(3); #[cfg(feature = "backend-dilithium5")] -impl<'de> serde::Deserialize<'de> for Dilithium5PublicKey { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - let RawDilithium5PublicKey { kty, alg, pk, .. } = - RawDilithium5PublicKey::deserialize(deserializer)?; - check_key_constants::<Dilithium5PublicKey, D::Error>(kty, alg, Some(Crv::None))?; - let pk = pk.ok_or_else(|| D::Error::missing_field("pk"))?; - Ok(Self { pk }) - } -} +dilithium_public_key!(5);