Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
Thalhammer authored May 2, 2020
2 parents bb5df74 + 6b349fd commit f62fe8b
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 98 deletions.
68 changes: 55 additions & 13 deletions include/jwt-cpp/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,26 @@
#include <string>
#include <array>

#ifdef __has_cpp_attribute
#if __has_cpp_attribute(fallthrough)
#define JWT_FALLTHROUGH [[fallthrough]]
#endif
#endif

#ifndef JWT_FALLTHROUGH
#define JWT_FALLTHROUGH
#endif

namespace jwt {
namespace alphabet {
struct base64 {
static const std::array<char, 64>& data() {
static std::array<char, 64> data = {
{'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}};
return data;
static std::array<char, 64> data = {
{'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}};
return data;
};
static const std::string& fill() {
static std::string fill = "=";
Expand All @@ -20,12 +30,12 @@ namespace jwt {
};
struct base64url {
static const std::array<char, 64>& data() {
static std::array<char, 64> data = {
{'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'}};
return data;
static std::array<char, 64> data = {
{'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'}};
return data;
};
static const std::string& fill() {
static std::string fill = "%3d";
Expand All @@ -44,6 +54,14 @@ namespace jwt {
static std::string decode(const std::string& base) {
return decode(base, T::data(), T::fill());
}
template<typename T>
static std::string pad(const std::string& base) {
return pad(base, T::fill());
}
template<typename T>
static std::string trim(const std::string& base) {
return trim(base, T::fill());
}

private:
static std::string encode(const std::string& bin, const std::array<char, 64>& alphabet, const std::string& fill) {
Expand Down Expand Up @@ -120,7 +138,7 @@ namespace jwt {
auto get_sextet = [&](size_t offset) {
for (size_t i = 0; i < alphabet.size(); i++) {
if (alphabet[i] == base[offset])
return i;
return static_cast<uint32_t>(i);
}
throw std::runtime_error("Invalid input");
};
Expand Down Expand Up @@ -164,5 +182,29 @@ namespace jwt {

return res;
}

static std::string pad(const std::string& base, const std::string& fill) {
std::string padding;
switch (base.size() % 4) {
case 1:
padding += fill;
JWT_FALLTHROUGH;
case 2:
padding += fill;
JWT_FALLTHROUGH;
case 3:
padding += fill;
JWT_FALLTHROUGH;
default:
break;
}

return base + padding;
}

static std::string trim(const std::string& base, const std::string& fill) {
auto pos = base.find(fill);
return base.substr(0, pos);
}
};
}
150 changes: 69 additions & 81 deletions include/jwt-cpp/jwt.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <openssl/err.h>

//If openssl version less than 1.1
#if OPENSSL_VERSION_NUMBER < 269484032
#if OPENSSL_VERSION_NUMBER < 0x10100000L
#define OPENSSL10
#endif

Expand Down Expand Up @@ -74,11 +74,10 @@ namespace jwt {
namespace helper {
inline
std::string extract_pubkey_from_cert(const std::string& certstr, const std::string& pw = "") {
// TODO: Cannot find the exact version this change happended
#if OPENSSL_VERSION_NUMBER <= 0x10100003L
std::unique_ptr<BIO, decltype(&BIO_free_all)> certbio(BIO_new_mem_buf(const_cast<char*>(certstr.data()), certstr.size()), BIO_free_all);
#else
std::unique_ptr<BIO, decltype(&BIO_free_all)> certbio(BIO_new_mem_buf(certstr.data(), certstr.size()), BIO_free_all);
std::unique_ptr<BIO, decltype(&BIO_free_all)> certbio(BIO_new_mem_buf(certstr.data(), static_cast<int>(certstr.size())), BIO_free_all);
#endif
std::unique_ptr<BIO, decltype(&BIO_free_all)> keybio(BIO_new(BIO_s_mem()), BIO_free_all);

Expand All @@ -99,10 +98,12 @@ namespace jwt {
std::unique_ptr<BIO, decltype(&BIO_free_all)> pubkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
if(key.substr(0, 27) == "-----BEGIN CERTIFICATE-----") {
auto epkey = helper::extract_pubkey_from_cert(key, password);
if ((size_t)BIO_write(pubkey_bio.get(), epkey.data(), epkey.size()) != epkey.size())
const int len = static_cast<int>(epkey.size());
if (BIO_write(pubkey_bio.get(), epkey.data(), len) != len)
throw rsa_exception("failed to load public key: bio_write failed");
} else {
if ((size_t)BIO_write(pubkey_bio.get(), key.data(), key.size()) != key.size())
const int len = static_cast<int>(key.size());
if (BIO_write(pubkey_bio.get(), key.data(), len) != len)
throw rsa_exception("failed to load public key: bio_write failed");
}

Expand All @@ -115,13 +116,41 @@ namespace jwt {
inline
std::shared_ptr<EVP_PKEY> load_private_key_from_string(const std::string& key, const std::string& password = "") {
std::unique_ptr<BIO, decltype(&BIO_free_all)> privkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
if ((size_t)BIO_write(privkey_bio.get(), key.data(), key.size()) != key.size())
const int len = static_cast<int>(key.size());
if (BIO_write(privkey_bio.get(), key.data(), len) != len)
throw rsa_exception("failed to load private key: bio_write failed");
std::shared_ptr<EVP_PKEY> pkey(PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(password.c_str())), EVP_PKEY_free);
if (!pkey)
throw rsa_exception("failed to load private key: PEM_read_bio_PrivateKey failed");
return pkey;
}

/**
* Convert a OpenSSL BIGNUM to a std::string
* \param bn BIGNUM to convert
* \return bignum as string
*/
inline
#ifdef OPENSSL10
static std::string bn2raw(BIGNUM* bn)
#else
static std::string bn2raw(const BIGNUM* bn)
#endif
{
std::string res;
res.resize(BN_num_bytes(bn));
BN_bn2bin(bn, (unsigned char*)res.data());
return res;
}
/**
* Convert an std::string to a OpenSSL BIGNUM
* \param raw String to convert
* \return BIGNUM representation
*/
inline
static std::unique_ptr<BIGNUM, decltype(&BN_free)> raw2bn(const std::string& raw) {
return std::unique_ptr<BIGNUM, decltype(&BN_free)>(BN_bin2bn((const unsigned char*)raw.data(), static_cast<int>(raw.size()), nullptr), BN_free);
}
}

namespace algorithm {
Expand Down Expand Up @@ -166,9 +195,9 @@ namespace jwt {
*/
std::string sign(const std::string& data) const {
std::string res;
res.resize(EVP_MAX_MD_SIZE);
unsigned int len = res.size();
if (HMAC(md(), secret.data(), secret.size(), (const unsigned char*)data.data(), data.size(), (unsigned char*)res.data(), &len) == nullptr)
res.resize(static_cast<size_t>(EVP_MAX_MD_SIZE));
unsigned int len = static_cast<unsigned int>(res.size());
if (HMAC(md(), secret.data(), static_cast<int>(secret.size()), (const unsigned char*)data.data(), static_cast<int>(data.size()), (unsigned char*)res.data(), &len) == nullptr)
throw signature_generation_exception();
res.resize(len);
return res;
Expand Down Expand Up @@ -280,7 +309,7 @@ namespace jwt {
throw signature_verification_exception("failed to verify signature: VerifyInit failed");
if (!EVP_VerifyUpdate(ctx.get(), data.data(), data.size()))
throw signature_verification_exception("failed to verify signature: VerifyUpdate failed");
auto res = EVP_VerifyFinal(ctx.get(), (const unsigned char*)signature.data(), signature.size(), pkey.get());
auto res = EVP_VerifyFinal(ctx.get(), (const unsigned char*)signature.data(), static_cast<unsigned int>(signature.size()), pkey.get());
if (res != 1)
throw signature_verification_exception("evp verify final failed: " + std::to_string(res) + " " + ERR_error_string(ERR_get_error(), NULL));
}
Expand Down Expand Up @@ -319,10 +348,12 @@ namespace jwt {
std::unique_ptr<BIO, decltype(&BIO_free_all)> pubkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
if(public_key.substr(0, 27) == "-----BEGIN CERTIFICATE-----") {
auto epkey = helper::extract_pubkey_from_cert(public_key, public_key_password);
if ((size_t)BIO_write(pubkey_bio.get(), epkey.data(), epkey.size()) != epkey.size())
const int len = static_cast<int>(epkey.size());
if (BIO_write(pubkey_bio.get(), epkey.data(), len) != len)
throw ecdsa_exception("failed to load public key: bio_write failed");
} else {
if ((size_t)BIO_write(pubkey_bio.get(), public_key.data(), public_key.size()) != public_key.size())
const int len = static_cast<int>(public_key.size());
if (BIO_write(pubkey_bio.get(), public_key.data(), len) != len)
throw ecdsa_exception("failed to load public key: bio_write failed");
}

Expand All @@ -336,7 +367,8 @@ namespace jwt {

if (!private_key.empty()) {
std::unique_ptr<BIO, decltype(&BIO_free_all)> privkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
if ((size_t)BIO_write(privkey_bio.get(), private_key.data(), private_key.size()) != private_key.size())
const int len = static_cast<int>(private_key.size());
if (BIO_write(privkey_bio.get(), private_key.data(), len) != len)
throw ecdsa_exception("failed to load private key: bio_write failed");
pkey.reset(PEM_read_bio_ECPrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(private_key_password.c_str())), EC_KEY_free);
if (!pkey)
Expand All @@ -361,19 +393,19 @@ namespace jwt {
const std::string hash = generate_hash(data);

std::unique_ptr<ECDSA_SIG, decltype(&ECDSA_SIG_free)>
sig(ECDSA_do_sign((const unsigned char*)hash.data(), hash.size(), pkey.get()), ECDSA_SIG_free);
sig(ECDSA_do_sign((const unsigned char*)hash.data(), static_cast<int>(hash.size()), pkey.get()), ECDSA_SIG_free);
if(!sig)
throw signature_generation_exception();
#ifdef OPENSSL10

auto rr = bn2raw(sig->r);
auto rs = bn2raw(sig->s);
auto rr = helper::bn2raw(sig->r);
auto rs = helper::bn2raw(sig->s);
#else
const BIGNUM *r;
const BIGNUM *s;
ECDSA_SIG_get0(sig.get(), &r, &s);
auto rr = bn2raw(r);
auto rs = bn2raw(s);
auto rr = helper::bn2raw(r);
auto rs = helper::bn2raw(s);
#endif
if(rr.size() > signature_length/2 || rs.size() > signature_length/2)
throw std::logic_error("bignum size exceeded expected length");
Expand All @@ -390,8 +422,8 @@ namespace jwt {
*/
void verify(const std::string& data, const std::string& signature) const {
const std::string hash = generate_hash(data);
auto r = raw2bn(signature.substr(0, signature.size() / 2));
auto s = raw2bn(signature.substr(signature.size() / 2));
auto r = helper::raw2bn(signature.substr(0, signature.size() / 2));
auto s = helper::raw2bn(signature.substr(signature.size() / 2));

#ifdef OPENSSL10
ECDSA_SIG sig;
Expand All @@ -405,7 +437,7 @@ namespace jwt {

ECDSA_SIG_set0(sig.get(), r.release(), s.release());

if(ECDSA_do_verify((const unsigned char*)hash.data(), hash.size(), sig.get(), pkey.get()) != 1)
if(ECDSA_do_verify((const unsigned char*)hash.data(), static_cast<int>(hash.size()), sig.get(), pkey.get()) != 1)
throw signature_verification_exception("Invalid signature");
#endif
}
Expand All @@ -417,31 +449,6 @@ namespace jwt {
return alg_name;
}
private:
/**
* Convert a OpenSSL BIGNUM to a std::string
* \param bn BIGNUM to convert
* \return bignum as string
*/
#ifdef OPENSSL10
static std::string bn2raw(BIGNUM* bn)
#else
static std::string bn2raw(const BIGNUM* bn)
#endif
{
std::string res;
res.resize(BN_num_bytes(bn));
BN_bn2bin(bn, (unsigned char*)res.data());
return res;
}
/**
* Convert an std::string to a OpenSSL BIGNUM
* \param raw String to convert
* \return BIGNUM representation
*/
static std::unique_ptr<BIGNUM, decltype(&BN_free)> raw2bn(const std::string& raw) {
return std::unique_ptr<BIGNUM, decltype(&BN_free)>(BN_bin2bn((const unsigned char*)raw.data(), raw.size(), nullptr), BN_free);
}

/**
* Hash the provided data using the hash function specified in constructor
* \param data Data to hash
Expand Down Expand Up @@ -533,7 +540,7 @@ namespace jwt {
const int size = RSA_size(key.get());

std::string sig(size, 0x00);
if(!RSA_public_decrypt(signature.size(), (const unsigned char*)signature.data(), (unsigned char*)sig.data(), key.get(), RSA_NO_PADDING))
if(!RSA_public_decrypt(static_cast<int>(signature.size()), (const unsigned char*)signature.data(), (unsigned char*)sig.data(), key.get(), RSA_NO_PADDING))
throw signature_verification_exception("Invalid signature");

if(!RSA_verify_PKCS1_PSS_mgf1(key.get(), (const unsigned char*)hash.data(), md(), md(), (const unsigned char*)sig.data(), -1))
Expand Down Expand Up @@ -1139,36 +1146,9 @@ namespace jwt {
signature = signature_base64 = token.substr(payload_end + 1);

// Fix padding: JWT requires padding to get removed
auto fix_padding = [](std::string& str) {
switch (str.size() % 4) {
case 1:
str += alphabet::base64url::fill();
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(fallthrough)
[[fallthrough]];
#endif
#endif
case 2:
str += alphabet::base64url::fill();
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(fallthrough)
[[fallthrough]];
#endif
#endif
case 3:
str += alphabet::base64url::fill();
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(fallthrough)
[[fallthrough]];
#endif
#endif
default:
break;
}
};
fix_padding(header);
fix_padding(payload);
fix_padding(signature);
header = base::pad<alphabet::base64url>(header);
payload = base::pad<alphabet::base64url>(payload);
signature = base::pad<alphabet::base64url>(signature);

header = base::decode<alphabet::base64url>(header);
payload = base::decode<alphabet::base64url>(payload);
Expand Down Expand Up @@ -1344,10 +1324,7 @@ namespace jwt {
}

auto encode = [](const std::string& data) {
auto base = base::encode<alphabet::base64url>(data);
auto pos = base.find(alphabet::base64url::fill());
base = base.substr(0, pos);
return base;
return base::trim<alphabet::base64url>(base::encode<alphabet::base64url>(data));
};

std::string header = encode(picojson::value(obj_header).serialize());
Expand Down Expand Up @@ -1440,6 +1417,13 @@ namespace jwt {
* \return *this to allow chaining
*/
verifier& with_audience(const std::set<std::string>& aud) { return with_claim("aud", claim(aud)); }
/**
* Set an audience to check for.
* If the specified audiences is not present in the token the check fails.
* \param aud Audience to check for.
* \return *this to allow chaining
*/
verifier& with_audience(const std::string& aud) { return with_claim("aud", claim(aud)); }
/**
* Set an id to check for.
* Check is casesensitive.
Expand Down Expand Up @@ -1501,6 +1485,10 @@ namespace jwt {
throw token_verification_exception("claim " + key + " does not match expected");
}
}
else if (c.get_type() == claim::type::object) {
if( c.to_json().serialize() != jc.to_json().serialize())
throw token_verification_exception("claim " + key + " does not match expected");
}
else if (c.get_type() == claim::type::string) {
if (c.as_string() != jc.as_string())
throw token_verification_exception("claim " + key + " does not match expected");
Expand Down
Loading

0 comments on commit f62fe8b

Please sign in to comment.