Skip to content

Commit

Permalink
Refactor JOSE portion of project out
Browse files Browse the repository at this point in the history
  • Loading branch information
tgross35 committed Oct 2, 2023
1 parent 24ac5d3 commit 1233695
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 122 deletions.
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ pub enum Error {
Json(serde_json::Error),
Jose(josekit::JoseError),
VerifyKey,
KeyType(Box<str>),
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::VerifyKey => write!(f, "missing a key marked 'verify'"),
Self::KeyType(v) => write!(f, "unsupported key type {v}"),
_ => write!(f, ""),
}
}
Expand Down
132 changes: 132 additions & 0 deletions src/jose.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
use std::fmt;
use std::ops::Deref;
use std::{sync::OnceLock, time::Duration};

use crate::util::{b64_to_bytes, b64_to_str};
use crate::{Error, Result};
use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
use josekit::jwk::Jwk;
use josekit::jws::alg::ecdsa::EcdsaJwsAlgorithm;
use josekit::jws::alg::eddsa::EddsaJwsAlgorithm;
use josekit::jws::{self, JwsAlgorithm, JwsVerifier};
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::{json, Value};

/// Representation of a tang advertisment response which is a JWS of available keys.
#[derive(Deserialize)]
pub struct Advertisment {
#[serde(deserialize_with = "b64_to_str")]
protected: String,
#[serde(deserialize_with = "b64_to_str")]
payload: String,
#[serde(deserialize_with = "b64_to_bytes")]
signature: Vec<u8>,
}

impl Advertisment {
/// Validate the entire advertisment. This checks the `verify` key correctly signs the data.
fn validate(&self, jwks: &JwkSet) -> Result<()> {
let verify_jwk = jwks.get_key_by_op("verify")?;
let verifier = get_verifier(verify_jwk)?;

// B64 is 4/3 data length, plus a `.`
let verify_len = ((self.payload.len() + self.protected.len()) * 4 / 3) + 1;
let mut to_verify = String::with_capacity(verify_len);

// The format `b64(HEADER).b64(PAYLOAD)` is used for validation
BASE64_URL_SAFE_NO_PAD.encode_string(&self.protected, &mut to_verify);
to_verify.push('.');
BASE64_URL_SAFE_NO_PAD.encode_string(&self.payload, &mut to_verify);

verifier
.verify(to_verify.as_bytes(), &self.signature)
.map_err(Into::into)
}

/// Validate the advertisment and extract its keys
pub fn into_keys(self) -> Result<JwkSet> {
let jwks: JwkSet = serde_json::from_str(&self.payload)?;
self.validate(&jwks)?;
Ok(jwks)
}
}

impl fmt::Debug for Advertisment {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fn json_field(s: &str) -> Box<dyn fmt::Debug + '_> {
match serde_json::from_str::<Value>(s) {
Ok(v) => Box::new(v),
Err(_) => Box::new(s),
}
}

f.debug_struct("Advertisment")
.field("payload", &json_field(&self.payload))
.field("protected", &json_field(&self.protected))
.field("signature", &BASE64_URL_SAFE_NO_PAD.encode(&self.signature))
.finish()
}
}

#[derive(Debug, Deserialize)]
pub struct JwkSet {
keys: Vec<Jwk>,
}

impl JwkSet {
/// Get a single key that contains an operation
fn get_key_by_op(&self, op_name: &str) -> Result<&Jwk> {
self.keys
.iter()
.find(|key| {
key.key_operations().map_or(false, |key_ops| {
key_ops.iter().any(|op| op.eq_ignore_ascii_case(op_name))
})
})
.ok_or(Error::MissingKeyOp(op_name.into()))
}
}

/// The key types we support
#[derive(Clone, Copy, Debug, PartialEq)]
enum KeyType {
Ec,
Rsa,
}

/// Extract the key type of a jwk
fn key_type(jwk: &Jwk) -> Result<KeyType> {
match jwk.key_type() {
"EC" => Ok(KeyType::Ec),
"RSA" => Ok(KeyType::Rsa),
_ => Err(Error::KeyType(jwk.key_type().into())),
}
}

/// Get a verifier from a JWK
fn get_verifier(jwk: &Jwk) -> Result<Box<dyn JwsVerifier>> {
let kty = key_type(jwk)?;
if kty == KeyType::Ec {
jws::ES512
.verifier_from_jwk(jwk)
.or_else(|_| jws::ES256.verifier_from_jwk(jwk))
.or_else(|_| jws::ES256K.verifier_from_jwk(jwk))
.or_else(|_| jws::ES384.verifier_from_jwk(jwk))
.map(|v| Box::new(v) as Box<dyn JwsVerifier>)
} else if kty == KeyType::Rsa {
jws::RS256
.verifier_from_jwk(jwk)
.or_else(|_| jws::RS384.verifier_from_jwk(jwk))
.or_else(|_| jws::RS512.verifier_from_jwk(jwk))
.map(|v| Box::new(v) as Box<dyn JwsVerifier>)
} else {
unreachable!()
}
.map_err(Into::into)
}

fn make_thumbprint(jwk: &Jwk) {}

#[cfg(test)]
#[path = "jose_tests.rs"]
mod tests;
8 changes: 4 additions & 4 deletions src/tests.rs → src/jose_tests.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use super::*;
use serde_json::{json, Value};

use crate::tang_interface::Advertisment;

/// Sample JWS as provided from a tang server
const SAMPLE_JWS: &str = concat!(
pub const SAMPLE_JWS: &str = concat!(
r#"{"payload": ""#,
// The payload contains `{"keys": [...]}` with the two keys below
"eyJrZXlzIjogW3siYWxnIjogIkVDTVIiLCAia3R5IjogIkVDIiwgImNydiI6ICJQLTUyMSIsICJ4IjogIkFGa3preGxGa\
Expand Down Expand Up @@ -49,6 +48,7 @@ const SAMPLE_JWK_VERIFY_NAME: &str = "wUNL__gwORwHmgKjKvVnK2rCFEWOu1oM65na-9iVcq

#[test]
fn test_verify() {
// Ensure we can extract and validate the keys
let adv: Advertisment = serde_json::from_str(SAMPLE_JWS).unwrap();
let _ = adv.validate().unwrap();
let _ = adv.into_keys().unwrap();
}
4 changes: 1 addition & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
mod decrypt;
mod encrypt;
mod error;
mod jose;
mod tang_interface;
mod util;

Expand All @@ -13,6 +14,3 @@ pub use encrypt::{EncryptConfig, EncryptSource};
pub use tang_interface::TangClient;

pub use error::{Error, Result};

#[cfg(test)]
mod tests;
118 changes: 3 additions & 115 deletions src/tang_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt;
use std::ops::Deref;
use std::{sync::OnceLock, time::Duration};

use crate::jose::Advertisment;
use crate::util::{b64_to_bytes, b64_to_str};
use crate::{Error, Result};
use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
Expand Down Expand Up @@ -40,9 +41,8 @@ impl TangClient {
pub fn fetch_public_keys(&self) -> Result<()> {
let url = format!("{}/adv", &self.url);
log::debug!("fetching advertisment from '{url}'");
let keys: Advertisment = ureq::get(&url).timeout(self.timeout).call()?.into_json()?;
dbg!(&keys);
keys.validate()?;
let adv: Advertisment = ureq::get(&url).timeout(self.timeout).call()?.into_json()?;
let keys = adv.into_keys();
Ok(())
}

Expand All @@ -52,115 +52,3 @@ impl TangClient {
// /// Perform recovery
// pub fn recover_key(url: &str, key_id: String) {}
}

/// Representation of a tang advertisment response which is a JWS of available keys.
#[derive(Deserialize)]
pub struct Advertisment {
#[serde(deserialize_with = "b64_to_str")]
protected: String,
#[serde(deserialize_with = "b64_to_str")]
payload: String,
#[serde(deserialize_with = "b64_to_bytes")]
signature: Vec<u8>,
}

impl Advertisment {
/// Validate the entire advertisment. This checks the `verify` key correctly signs the data.
pub fn validate(&self) -> Result<()> {
let jwks: JwkSet = serde_json::from_str(&self.payload)?;
let verify_jwk = jwks.get_key_by_op("verify")?;
let verifier = get_verifier(verify_jwk)?;

// B64 is 4/3 data length, plus a `.`
let verify_len = ((self.payload.len() + self.protected.len()) * 4 / 3) + 1;
let mut to_verify = String::with_capacity(verify_len);

// The format `b64(HEADER).b64(PAYLOAD)` is used for validation
BASE64_URL_SAFE_NO_PAD.encode_string(&self.protected, &mut to_verify);
to_verify.push('.');
BASE64_URL_SAFE_NO_PAD.encode_string(&self.payload, &mut to_verify);

verifier
.verify(to_verify.as_bytes(), &self.signature)
.map_err(Into::into)
}

fn signature(&self) {}
}

impl fmt::Debug for Advertisment {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fn json_field(s: &str) -> Box<dyn fmt::Debug + '_> {
match serde_json::from_str::<Value>(s) {
Ok(v) => Box::new(v),
Err(_) => Box::new(s),
}
}

f.debug_struct("Advertisment")
.field("payload", &json_field(&self.payload))
.field("protected", &json_field(&self.protected))
.field("signature", &BASE64_URL_SAFE_NO_PAD.encode(&self.signature))
.finish()
}
}

#[derive(Debug, Deserialize)]
pub struct JwkSet {
keys: Vec<Jwk>,
}

impl JwkSet {
fn get_key_by_op(&self, op_name: &str) -> Result<&Jwk> {
self.keys
.iter()
.find(|key| {
key.key_operations().map_or(false, |key_ops| {
key_ops.iter().any(|op| op.eq_ignore_ascii_case(op_name))
})
})
.ok_or(Error::MissingKeyOp(op_name.into()))
}
}

/// Get a verifier from a JWK
fn get_verifier(jwk: &Jwk) -> Result<Box<dyn JwsVerifier>> {
// Start with most likely algorithms
jws::ES512
.verifier_from_jwk(jwk)
.or_else(|_| jws::ES256.verifier_from_jwk(jwk))
.or_else(|_| jws::ES256K.verifier_from_jwk(jwk))
.or_else(|_| jws::ES384.verifier_from_jwk(jwk))
.map(|v| Box::new(v) as Box<dyn JwsVerifier>)
.or_else(|_| {
// EdDSA
jws::EdDSA
.verifier_from_jwk(jwk)
.map(|v| Box::new(v) as Box<dyn JwsVerifier>)
})
.or_else(|_| {
// HMAC
jws::HS256
.verifier_from_jwk(jwk)
.or_else(|_| jws::HS384.verifier_from_jwk(jwk))
.or_else(|_| jws::HS512.verifier_from_jwk(jwk))
.map(|v| Box::new(v) as Box<dyn JwsVerifier>)
})
.or_else(|_| {
// RSA
jws::RS256
.verifier_from_jwk(jwk)
.or_else(|_| jws::RS384.verifier_from_jwk(jwk))
.or_else(|_| jws::RS512.verifier_from_jwk(jwk))
.map(|v| Box::new(v) as Box<dyn JwsVerifier>)
})
.or_else(|_| {
// RSA PSS
jws::PS256
.verifier_from_jwk(jwk)
.or_else(|_| jws::PS384.verifier_from_jwk(jwk))
.or_else(|_| jws::PS512.verifier_from_jwk(jwk))
.map(|v| Box::new(v) as Box<dyn JwsVerifier>)
})
.map_err(Into::into)
}

0 comments on commit 1233695

Please sign in to comment.