Skip to content

feat: allow users to specify a custom header not defined in the struct #420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 4, 2025
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ If you want to set the `kid` parameter or change the algorithm for example:
```rust
let mut header = Header::new(Algorithm::HS512);
header.kid = Some("blabla".to_owned());

let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());
header.extras = Some(extras);

let token = encode(&header, &my_claims, &EncodingKey::from_secret("secret".as_ref()))?;
```
Look at `examples/custom_header.rs` for a full working example.
Expand Down
15 changes: 14 additions & 1 deletion benches/jwt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
struct Claims {
Expand All @@ -17,6 +18,18 @@ fn bench_encode(c: &mut Criterion) {
});
}

fn bench_encode_custom_extra_headers(c: &mut Criterion) {
let claim = Claims { sub: "[email protected]".to_owned(), company: "ACME".to_owned() };
let key = EncodingKey::from_secret("secret".as_ref());
let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());
let header = &Header { extras, ..Default::default() };

c.bench_function("bench_encode", |b| {
b.iter(|| encode(black_box(header), black_box(&claim), black_box(&key)))
});
}

fn bench_decode(c: &mut Criterion) {
let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ";
let key = DecodingKey::from_secret("secret".as_ref());
Expand All @@ -32,5 +45,5 @@ fn bench_decode(c: &mut Criterion) {
});
}

criterion_group!(benches, bench_encode, bench_decode);
criterion_group!(benches, bench_encode, bench_encode_custom_extra_headers, bench_decode);
criterion_main!(benches);
12 changes: 10 additions & 2 deletions examples/custom_header.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
Expand All @@ -15,8 +16,15 @@ fn main() {
Claims { sub: "[email protected]".to_owned(), company: "ACME".to_owned(), exp: 10000000000 };
let key = b"secret";

let header =
Header { kid: Some("signing_key".to_owned()), alg: Algorithm::HS512, ..Default::default() };
let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());

let header = Header {
kid: Some("signing_key".to_owned()),
alg: Algorithm::HS512,
extras,
..Default::default()
};

let token = match encode(&header, &my_claims, &EncodingKey::from_secret(key)) {
Ok(t) => t,
Expand Down
10 changes: 9 additions & 1 deletion src/header.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::result;

use base64::{engine::general_purpose::STANDARD, Engine};
Expand All @@ -10,7 +11,7 @@ use crate::serialization::b64_decode;

/// A basic JWT header, the alg defaults to HS256 and typ is automatically
/// set to `JWT`. All the other fields are optional.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Header {
/// The type of JWS: it can only be "JWT" here
///
Expand Down Expand Up @@ -64,6 +65,12 @@ pub struct Header {
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "x5t#S256")]
pub x5t_s256: Option<String>,

/// Any additional non-standard headers not defined in [RFC7515#4.1](https://datatracker.ietf.org/doc/html/rfc7515#section-4.1).
/// Once serialized, all keys will be converted to fields at the root level of the header payload
/// Ex: Dict("custom" -> "header") will be converted to "{"typ": "JWT", ..., "custom": "header"}"
#[serde(flatten)]
pub extras: HashMap<String, String>,
}

impl Header {
Expand All @@ -80,6 +87,7 @@ impl Header {
x5c: None,
x5t: None,
x5t_s256: None,
extras: Default::default(),
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/jwk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl<'de> Deserialize<'de> for PublicKeyUse {
D: Deserializer<'de>,
{
struct PublicKeyUseVisitor;
impl<'de> de::Visitor<'de> for PublicKeyUseVisitor {
impl de::Visitor<'_> for PublicKeyUseVisitor {
type Value = PublicKeyUse;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -116,7 +116,7 @@ impl<'de> Deserialize<'de> for KeyOperations {
D: Deserializer<'de>,
{
struct KeyOperationsVisitor;
impl<'de> de::Visitor<'de> for KeyOperationsVisitor {
impl de::Visitor<'_> for KeyOperationsVisitor {
type Value = KeyOperations;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down
2 changes: 2 additions & 0 deletions src/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use serde::{Deserialize, Serialize};

use crate::errors::Result;

#[inline]
pub(crate) fn b64_encode<T: AsRef<[u8]>>(input: T) -> String {
URL_SAFE_NO_PAD.encode(input)
}

#[inline]
pub(crate) fn b64_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>> {
URL_SAFE_NO_PAD.decode(input).map_err(|e| e.into())
}
Expand Down
16 changes: 8 additions & 8 deletions src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,20 @@ where
{
struct NumericType(PhantomData<fn() -> TryParse<u64>>);

impl<'de> Visitor<'de> for NumericType {
impl Visitor<'_> for NumericType {
type Value = TryParse<u64>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("A NumericType that can be reasonably coerced into a u64")
}

fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
Ok(TryParse::Parsed(value))
}

fn visit_f64<E>(self, value: f64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
Expand All @@ -354,13 +361,6 @@ where
Err(serde::de::Error::custom("NumericType must be representable as a u64"))
}
}

fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
Ok(TryParse::Parsed(value))
}
}

match deserializer.deserialize_any(NumericType(PhantomData)) {
Expand Down
70 changes: 70 additions & 0 deletions tests/hmac.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use jsonwebtoken::{
decode, decode_header, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use time::OffsetDateTime;
use wasm_bindgen_test::wasm_bindgen_test;

Expand Down Expand Up @@ -51,6 +52,56 @@ fn encode_with_custom_header() {
.unwrap();
assert_eq!(my_claims, token_data.claims);
assert_eq!("kid", token_data.header.kid.unwrap());
assert!(token_data.header.extras.is_empty());
}

#[test]
#[wasm_bindgen_test]
fn encode_with_extra_custom_header() {
let my_claims = Claims {
sub: "[email protected]".to_string(),
company: "ACME".to_string(),
exp: OffsetDateTime::now_utc().unix_timestamp() + 10000,
};
let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());
let header = Header { kid: Some("kid".to_string()), extras, ..Default::default() };
let token = encode(&header, &my_claims, &EncodingKey::from_secret(b"secret")).unwrap();
let token_data = decode::<Claims>(
&token,
&DecodingKey::from_secret(b"secret"),
&Validation::new(Algorithm::HS256),
)
.unwrap();
assert_eq!(my_claims, token_data.claims);
assert_eq!("kid", token_data.header.kid.unwrap());
assert_eq!("header", token_data.header.extras.get("custom").unwrap().as_str());
}

#[test]
#[wasm_bindgen_test]
fn encode_with_multiple_extra_custom_headers() {
let my_claims = Claims {
sub: "[email protected]".to_string(),
company: "ACME".to_string(),
exp: OffsetDateTime::now_utc().unix_timestamp() + 10000,
};
let mut extras = HashMap::with_capacity(2);
extras.insert("custom1".to_string(), "header1".to_string());
extras.insert("custom2".to_string(), "header2".to_string());
let header = Header { kid: Some("kid".to_string()), extras, ..Default::default() };
let token = encode(&header, &my_claims, &EncodingKey::from_secret(b"secret")).unwrap();
let token_data = decode::<Claims>(
&token,
&DecodingKey::from_secret(b"secret"),
&Validation::new(Algorithm::HS256),
)
.unwrap();
assert_eq!(my_claims, token_data.claims);
assert_eq!("kid", token_data.header.kid.unwrap());
let extras = token_data.header.extras;
assert_eq!("header1", extras.get("custom1").unwrap().as_str());
assert_eq!("header2", extras.get("custom2").unwrap().as_str());
}

#[test]
Expand Down Expand Up @@ -86,6 +137,25 @@ fn decode_token() {
claims.unwrap();
}

#[test]
#[wasm_bindgen_test]
fn decode_token_custom_headers() {
let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsImN1c3RvbTEiOiJoZWFkZXIxIiwiY3VzdG9tMiI6ImhlYWRlcjIifQ.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.FtOHsoKcNH3SriK3tnR-uWJg4UV4FkOzvq_JCfLngfU";
let claims = decode::<Claims>(
token,
&DecodingKey::from_secret(b"secret"),
&Validation::new(Algorithm::HS256),
)
.unwrap();
let my_claims =
Claims { sub: "[email protected]".to_string(), company: "ACME".to_string(), exp: 2532524891 };
assert_eq!(my_claims, claims.claims);
assert_eq!("kid", claims.header.kid.unwrap());
let extras = claims.header.extras;
assert_eq!("header1", extras.get("custom1").unwrap().as_str());
assert_eq!("header2", extras.get("custom2").unwrap().as_str());
}

#[test]
#[wasm_bindgen_test]
#[should_panic(expected = "InvalidToken")]
Expand Down
Loading