Skip to content

Commit

Permalink
Serialization for Dilithium keys
Browse files Browse the repository at this point in the history
  • Loading branch information
SWilson4 committed Jan 7, 2025
1 parent a1e92f2 commit 0802e4c
Showing 1 changed file with 196 additions and 51 deletions.
247 changes: 196 additions & 51 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
//! Key Type 4 (Symmetric)
//! -1: k (key value)
//!
//! Key Type 6 (PQC)
//! -1: pk (public key value)
/*
COSE_Key = {
Expand All @@ -51,16 +53,14 @@ use serde::{
de::{Error as _, Expected, MapAccess, Unexpected},
Deserialize, Serialize,
};
#[cfg(feature = "backend-dilithium")]
use serde_big_array::BigArray;
use serde_repr::{Deserialize_repr, Serialize_repr};

#[repr(i8)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize_repr, Deserialize_repr)]
enum Label {
Kty = 1,
Alg = 3,
Crv = -1,
CrvOrPk = -1,
X = -2,
Y = -3,
}
Expand All @@ -74,7 +74,7 @@ impl TryFrom<i8> for Label {
Ok(match label {
1 => Self::Kty,
3 => Self::Alg,
-1 => Self::Crv,
-1 => Self::CrvOrPk,
-2 => Self::X,
-3 => Self::Y,
_ => {
Expand All @@ -90,6 +90,8 @@ enum Kty {
Okp = 1,
Ec2 = 2,
Symmetric = 4,
#[cfg(feature = "backend-dilithium")]
Pqc = 7,
}

impl Expected for Kty {
Expand Down Expand Up @@ -212,11 +214,22 @@ impl From<Dilithium5PublicKey> for PublicKey {
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
enum CrvOrPk {
Ec(Crv),
#[cfg(feature = "backend-dilithium2")]
Dilithium2(Bytes<{ ffi::PQCLEAN_DILITHIUM2_CLEAN_CRYPTO_PUBLICKEYBYTES }>),
#[cfg(feature = "backend-dilithium3")]
Dilithium3(Bytes<{ ffi::PQCLEAN_DILITHIUM3_CLEAN_CRYPTO_PUBLICKEYBYTES }>),
#[cfg(feature = "backend-dilithium5")]
Dilithium5(Bytes<{ ffi::PQCLEAN_DILITHIUM5_CLEAN_CRYPTO_PUBLICKEYBYTES }>)
}

#[derive(Clone, Debug, Default)]
struct RawPublicKey {
kty: Option<Kty>,
alg: Option<Alg>,
crv: Option<Crv>,
crv_or_pk: Option<CrvOrPk>,
x: Option<Bytes<32>>,
y: Option<Bytes<32>>,
}
Expand Down Expand Up @@ -277,8 +290,8 @@ impl<'de> Deserialize<'de> for RawPublicKey {
key = next_key(&mut map)?;
}

if key == Key::Label(Label::Crv) {
public_key.crv = Some(map.next_value()?);
if key == Key::Label(Label::CrvOrPk) {
public_key.crv_or_pk = Some(map.next_value()?);
key = next_key(&mut map)?;
}

Expand Down Expand Up @@ -314,7 +327,7 @@ impl Serialize for RawPublicKey {
let is_set = [
self.kty.is_some(),
self.alg.is_some(),
self.crv.is_some(),
self.crv_or_pk.is_some(),
self.x.is_some(),
self.y.is_some(),
];
Expand All @@ -330,9 +343,17 @@ impl Serialize for RawPublicKey {
if let Some(alg) = &self.alg {
map.serialize_entry(&(Label::Alg as i8), &(*alg as i8))?;
}
// -1: crv
if let Some(crv) = &self.crv {
map.serialize_entry(&(Label::Crv as i8), &(*crv as i8))?;
// -1: crv or public key
if let Some(crv_or_pk) = &self.crv_or_pk {
match crv_or_pk {
CrvOrPk::Ec(crv) => map.serialize_entry(&(Label::CrvOrPk as i8), &(*crv as i8))?,
#[cfg(feature = "backend-dilithium2")]
CrvOrPk::Dilithium2(pk) => map.serialize_entry(&(Label::CrvOrPk as i8), pk)?,
#[cfg(feature = "backend-dilithium3")]
CrvOrPk::Dilithium3(pk) => map.serialize_entry(&(Label::CrvOrPk as i8), pk)?,
#[cfg(feature = "backend-dilithium5")]
CrvOrPk::Dilithium5(pk) => map.serialize_entry(&(Label::CrvOrPk as i8), pk)?,
}
}
// -2: x
if let Some(x) = &self.x {
Expand Down Expand Up @@ -371,7 +392,7 @@ impl From<P256PublicKey> for RawPublicKey {
Self {
kty: Some(P256PublicKey::KTY),
alg: Some(P256PublicKey::ALG),
crv: Some(P256PublicKey::CRV),
crv_or_pk: Some(CrvOrPk::Ec(P256PublicKey::CRV)),
x: Some(key.x),
y: Some(key.y),
}
Expand All @@ -396,7 +417,7 @@ impl From<EcdhEsHkdf256PublicKey> for RawPublicKey {
Self {
kty: Some(EcdhEsHkdf256PublicKey::KTY),
alg: Some(EcdhEsHkdf256PublicKey::ALG),
crv: Some(EcdhEsHkdf256PublicKey::CRV),
crv_or_pk: Some(CrvOrPk::Ec(EcdhEsHkdf256PublicKey::CRV)),
x: Some(key.x),
y: Some(key.y),
}
Expand All @@ -420,13 +441,94 @@ impl From<Ed25519PublicKey> for RawPublicKey {
Self {
kty: Some(Ed25519PublicKey::KTY),
alg: Some(Ed25519PublicKey::ALG),
crv: Some(Ed25519PublicKey::CRV),
crv_or_pk: Some(CrvOrPk::Ec(Ed25519PublicKey::CRV)),
x: Some(key.x),
y: None,
}
}
}

#[cfg(feature = "backend-dilithium2")]
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
#[serde(into = "RawPublicKey")]
pub struct Dilithium2PublicKey {
pub pk: Bytes<{ ffi::PQCLEAN_DILITHIUM2_CLEAN_CRYPTO_PUBLICKEYBYTES }>,
}

#[cfg(feature = "backend-dilithium2")]
impl PublicKeyConstants for Dilithium2PublicKey {
const KTY: Kty = Kty::Pqc;
const ALG: Alg = Alg::Dilithium2;
const CRV: Crv = Crv::None;
}

#[cfg(feature = "backend-dilithium2")]
impl From<Dilithium2PublicKey> for RawPublicKey {
fn from(key: Dilithium2PublicKey) -> Self {
Self {
kty: Some(Dilithium2PublicKey::KTY),
alg: Some(Dilithium2PublicKey::ALG),
crv_or_pk: Some(CrvOrPk::Dilithium2(key.pk)),
x: None,
y: None,
}
}
}

#[cfg(feature = "backend-dilithium3")]
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
#[serde(into = "RawPublicKey")]
pub struct Dilithium3PublicKey {
pub pk: Bytes<{ ffi::PQCLEAN_DILITHIUM3_CLEAN_CRYPTO_PUBLICKEYBYTES }>,
}

#[cfg(feature = "backend-dilithium3")]
impl PublicKeyConstants for Dilithium3PublicKey {
const KTY: Kty = Kty::Pqc;
const ALG: Alg = Alg::Dilithium3;
const CRV: Crv = Crv::None;
}

#[cfg(feature = "backend-dilithium3")]
impl From<Dilithium3PublicKey> for RawPublicKey {
fn from(key: Dilithium3PublicKey) -> Self {
Self {
kty: Some(Dilithium3PublicKey::KTY),
alg: Some(Dilithium3PublicKey::ALG),
crv_or_pk: Some(CrvOrPk::Dilithium3(key.pk)),
x: None,
y: None,
}
}
}

#[cfg(feature = "backend-dilithium5")]
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
#[serde(into = "RawPublicKey")]
pub struct Dilithium5PublicKey {
pub pk: Bytes<{ ffi::PQCLEAN_DILITHIUM5_CLEAN_CRYPTO_PUBLICKEYBYTES }>,
}

#[cfg(feature = "backend-dilithium5")]
impl PublicKeyConstants for Dilithium5PublicKey {
const KTY: Kty = Kty::Pqc;
const ALG: Alg = Alg::Dilithium5;
const CRV: Crv = Crv::None;
}

#[cfg(feature = "backend-dilithium5")]
impl From<Dilithium5PublicKey> for RawPublicKey {
fn from(key: Dilithium5PublicKey) -> Self {
Self {
kty: Some(Dilithium5PublicKey::KTY),
alg: Some(Dilithium5PublicKey::ALG),
crv_or_pk: Some(CrvOrPk::Dilithium5(key.pk)),
x: None,
y: None,
}
}
}

#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize)]
#[serde(into = "RawPublicKey")]
pub struct TotpPublicKey {}
Expand All @@ -442,7 +544,7 @@ impl From<TotpPublicKey> for RawPublicKey {
Self {
kty: Some(TotpPublicKey::KTY),
alg: Some(TotpPublicKey::ALG),
crv: None,
crv_or_pk: None,
x: None,
y: None,
}
Expand All @@ -457,7 +559,7 @@ pub struct X25519PublicKey {
fn check_key_constants<K: PublicKeyConstants, E: serde::de::Error>(
kty: Option<Kty>,
alg: Option<Alg>,
crv: Option<Crv>,
crv_or_pk: &Option<CrvOrPk>,
) -> Result<(), E> {
let kty = kty.ok_or_else(|| E::missing_field("kty"))?;
if kty != K::KTY {
Expand All @@ -469,9 +571,19 @@ fn check_key_constants<K: PublicKeyConstants, E: serde::de::Error>(
}
}
if K::CRV != Crv::None {
let crv = crv.ok_or_else(|| E::missing_field("crv"))?;
if crv != K::CRV {
return Err(E::invalid_value(Unexpected::Signed(crv as _), &K::CRV));
let crv_or_pk = crv_or_pk.as_ref().ok_or_else(|| E::missing_field("crv_or_pk"))?;
match crv_or_pk {
CrvOrPk::Ec(crv) => {
if *crv != K::CRV {
return Err(E::invalid_value(Unexpected::Signed(*crv as _), &K::CRV));
}
},
#[cfg(feature = "backend-dilithium2")]
CrvOrPk::Dilithium2(pk) => return Err(E::invalid_type(Unexpected::Bytes(&pk as _), &K::CRV)),
#[cfg(feature = "backend-dilithium3")]
CrvOrPk::Dilithium3(pk) => return Err(E::invalid_type(Unexpected::Bytes(&pk as _), &K::CRV)),
#[cfg(feature = "backend-dilithium5")]
CrvOrPk::Dilithium5(pk) => return Err(E::invalid_type(Unexpected::Bytes(&pk as _), &K::CRV)),
}
}
Ok(())
Expand All @@ -485,11 +597,11 @@ impl<'de> serde::Deserialize<'de> for P256PublicKey {
let RawPublicKey {
kty,
alg,
crv,
crv_or_pk,
x,
y,
} = RawPublicKey::deserialize(deserializer)?;
check_key_constants::<P256PublicKey, D::Error>(kty, alg, crv)?;
check_key_constants::<P256PublicKey, D::Error>(kty, alg, &crv_or_pk)?;
let x = x.ok_or_else(|| D::Error::missing_field("x"))?;
let y = y.ok_or_else(|| D::Error::missing_field("y"))?;
Ok(Self { x, y })
Expand All @@ -504,11 +616,11 @@ impl<'de> serde::Deserialize<'de> for EcdhEsHkdf256PublicKey {
let RawPublicKey {
kty,
alg,
crv,
crv_or_pk,
x,
y,
} = RawPublicKey::deserialize(deserializer)?;
check_key_constants::<EcdhEsHkdf256PublicKey, D::Error>(kty, alg, crv)?;
check_key_constants::<EcdhEsHkdf256PublicKey, D::Error>(kty, alg, &crv_or_pk)?;
let x = x.ok_or_else(|| D::Error::missing_field("x"))?;
let y = y.ok_or_else(|| D::Error::missing_field("y"))?;
Ok(Self { x, y })
Expand All @@ -521,43 +633,76 @@ impl<'de> serde::Deserialize<'de> for Ed25519PublicKey {
D: serde::Deserializer<'de>,
{
let RawPublicKey {
kty, alg, crv, x, ..
kty, alg, crv_or_pk, x, ..
} = RawPublicKey::deserialize(deserializer)?;
check_key_constants::<Ed25519PublicKey, D::Error>(kty, alg, crv)?;
check_key_constants::<Ed25519PublicKey, D::Error>(kty, alg, &crv_or_pk)?;
let x = x.ok_or_else(|| D::Error::missing_field("x"))?;
Ok(Self { x })
}
}

#[cfg(feature = "backend-dilithium")]
macro_rules! dilithium_public_key {
($type_name: ident, $size: expr) => {
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct $type_name(#[serde(with = "BigArray")] [u8; $size]);
};
}

cfg_if::cfg_if! {
if #[cfg(feature = "backend-dilithium2")] {
dilithium_public_key!(
Dilithium2PublicKey,
ffi::PQCLEAN_DILITHIUM2_CLEAN_CRYPTO_PUBLICKEYBYTES
);
#[cfg(feature = "backend-dilithium2")]
impl<'de> serde::Deserialize<'de> for Dilithium2PublicKey {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let RawPublicKey {
kty, alg, crv_or_pk, ..
} = RawPublicKey::deserialize(deserializer)?;
check_key_constants::<Dilithium2PublicKey, D::Error>(kty, alg, &crv_or_pk)?;
let pk = crv_or_pk.ok_or_else(|| D::Error::missing_field("pk"))?;
match pk {
CrvOrPk::Dilithium2(pk) => Ok(Self { pk }),
#[cfg(feature = "backend-dilithium3")]
CrvOrPk::Dilithium3(pk) => Err(D::Error::invalid_length(pk.len(), &"expected 1312")),
#[cfg(feature = "backend-dilithium5")]
CrvOrPk::Dilithium5(pk) => Err(D::Error::invalid_length(pk.len(), &"expected 1312")),
CrvOrPk::Ec(crv) => return Err(D::Error::invalid_type(Unexpected::Signed(crv as _), &"expected Bytes")),
}
}
}
cfg_if::cfg_if! {
if #[cfg(feature = "backend-dilithium3")] {
dilithium_public_key!(
Dilithium3PublicKey,
ffi::PQCLEAN_DILITHIUM3_CLEAN_CRYPTO_PUBLICKEYBYTES
);

#[cfg(feature = "backend-dilithium3")]
impl<'de> serde::Deserialize<'de> for Dilithium3PublicKey {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let RawPublicKey {
kty, alg, crv_or_pk, ..
} = RawPublicKey::deserialize(deserializer)?;
check_key_constants::<Dilithium3PublicKey, D::Error>(kty, alg, &crv_or_pk)?;
let pk = crv_or_pk.ok_or_else(|| D::Error::missing_field("pk"))?;
match pk {
CrvOrPk::Dilithium3(pk) => Ok(Self { pk }),
#[cfg(feature = "backend-dilithium2")]
CrvOrPk::Dilithium2(pk) => Err(D::Error::invalid_length(pk.len(), &"expected 1952")),
#[cfg(feature = "backend-dilithium5")]
CrvOrPk::Dilithium5(pk) => Err(D::Error::invalid_length(pk.len(), &"expected 1952")),
CrvOrPk::Ec(crv) => return Err(D::Error::invalid_type(Unexpected::Signed(crv as _), &"expected Bytes")),
}
}
}
cfg_if::cfg_if! {
if #[cfg(feature = "backend-dilithium5")] {
dilithium_public_key!(
Dilithium5PublicKey,
ffi::PQCLEAN_DILITHIUM5_CLEAN_CRYPTO_PUBLICKEYBYTES
);

#[cfg(feature = "backend-dilithium5")]
impl<'de> serde::Deserialize<'de> for Dilithium5PublicKey {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let RawPublicKey {
kty, alg, crv_or_pk, ..
} = RawPublicKey::deserialize(deserializer)?;
check_key_constants::<Dilithium5PublicKey, D::Error>(kty, alg, &crv_or_pk)?;
let pk = crv_or_pk.ok_or_else(|| D::Error::missing_field("pk"))?;
match pk {
CrvOrPk::Dilithium5(pk) => Ok(Self { pk }),
#[cfg(feature = "backend-dilithium2")]
CrvOrPk::Dilithium2(pk) => Err(D::Error::invalid_length(pk.len(), &"expected 2592")),
#[cfg(feature = "backend-dilithium3")]
CrvOrPk::Dilithium3(pk) => Err(D::Error::invalid_length(pk.len(), &"expected 2592")),
CrvOrPk::Ec(crv) => return Err(D::Error::invalid_type(Unexpected::Signed(crv as _), &"expected Bytes")),
}
}
}

0 comments on commit 0802e4c

Please sign in to comment.