Skip to content

Commit

Permalink
Use EllSwift for handshake pubkeys and ECDH
Browse files Browse the repository at this point in the history
  • Loading branch information
Sjors committed Jan 9, 2024
1 parent 18e22c9 commit a1f476f
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 120 deletions.
135 changes: 61 additions & 74 deletions src/common/sv2_noise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <crypto/chacha20poly1305.h>
#include <crypto/hmac_sha256.h>
#include <crypto/poly1305.h>
#include <random.h>
#include <logging.h>
#include <util/check.h>
#include <util/strencodings.h>
Expand Down Expand Up @@ -120,7 +121,7 @@ void Sv2SymmetricState::MixHash(const Span<const std::byte> input)
m_hash_output = (HashWriter{} << m_hash_output << input).GetSHA256();
}

void Sv2SymmetricState::MixKey(const Span<const uint8_t> input_key_material)
void Sv2SymmetricState::MixKey(const Span<const std::byte> input_key_material)
{
uint8_t out0[KEY_SIZE], out1[KEY_SIZE];

Expand All @@ -136,11 +137,11 @@ void Sv2SymmetricState::LogChainingKey()
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Chaining key: %s\n", HexStr(m_chaining_key));
}

void Sv2SymmetricState::HKDF2(const Span<const uint8_t> input_key_material, uint8_t out0[KEY_SIZE], uint8_t out1[KEY_SIZE])
void Sv2SymmetricState::HKDF2(const Span<const std::byte> input_key_material, uint8_t out0[KEY_SIZE], uint8_t out1[KEY_SIZE])
{
uint8_t tmp_key[KEY_SIZE];
CHMAC_SHA256 tmp_mac(m_chaining_key, KEY_SIZE);
tmp_mac.Write(input_key_material.begin(), input_key_material.size());
tmp_mac.Write(UCharCast(input_key_material.data()), input_key_material.size());
tmp_mac.Finalize(tmp_key);

CHMAC_SHA256 out0_mac(tmp_key, KEY_SIZE);
Expand Down Expand Up @@ -179,7 +180,7 @@ std::array<Sv2CipherState, 2> Sv2SymmetricState::Split()
{
uint8_t send_key[KEY_SIZE], recv_key[KEY_SIZE];

std::vector<uint8_t> empty;
std::vector<std::byte> empty;
HKDF2(empty, send_key, recv_key);

std::array<Sv2CipherState, 2> result;
Expand All @@ -189,43 +190,43 @@ std::array<Sv2CipherState, 2> Sv2SymmetricState::Split()
return result;
}

void Sv2HandshakeState::GenerateEphemeralKey(CKey& key) noexcept
Sv2HandshakeState::Sv2HandshakeState(CKey&& static_key): m_static_key{static_key}
{
m_our_static_ellswift_pk = static_key.EllSwiftCreate(MakeByteSpan(GetRandHash()));
};


void Sv2HandshakeState::GenerateEphemeralKey() noexcept
{
Assume(!key.size());
key.MakeNewKey(true);
Assume(XOnlyPubKey(key.GetPubKey()).IsFullyValid());
Assume(!m_ephemeral_key.size());
m_ephemeral_key.MakeNewKey(true);
m_our_ephemeral_ellswift_pk = m_ephemeral_key.EllSwiftCreate(MakeByteSpan(GetRandHash()));
};

void Sv2HandshakeState::WriteMsgEphemeralPK(Span<std::byte> msg)
{
if (msg.size() < KEY_SIZE) {
throw std::runtime_error(strprintf("Invalid message size: %d bytes < %d", msg.size(), KEY_SIZE));
if (msg.size() < ELLSWIFT_KEY_SIZE) {
throw std::runtime_error(strprintf("Invalid message size: %d bytes < %d", msg.size(), ELLSWIFT_KEY_SIZE));
}

GenerateEphemeralKey(m_ephemeral_key);
GenerateEphemeralKey();

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Write our ephemeral key\n");
std::copy(m_our_ephemeral_ellswift_pk.begin(), m_our_ephemeral_ellswift_pk.end(), msg.begin());

auto ephemeral_pk = XOnlyPubKey(m_ephemeral_key.GetPubKey());
std::transform(ephemeral_pk.begin(), ephemeral_pk.end(), msg.begin(),
[](unsigned char b) { return static_cast<std::byte>(b); });

m_symmetric_state.MixHash(Span(msg.begin(), KEY_SIZE));
m_symmetric_state.MixHash(Span(msg.begin(), ELLSWIFT_KEY_SIZE));
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix hash: %s\n", HexStr(m_symmetric_state.m_hash_output));

std::vector<std::byte> empty;
m_symmetric_state.MixHash(empty);
}

void Sv2HandshakeState::ReadMsgEphemeralPK(Span<std::byte> msg) {
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Write their ephemeral key\n");
auto ucharSpan = UCharSpanCast(msg);
m_remote_ephemeral_key = XOnlyPubKey(Span(&ucharSpan[0], KEY_SIZE));
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Read their ephemeral key\n");
Assume(msg.size() == ELLSWIFT_KEY_SIZE);
m_remote_ephemeral_ellswift_pk = EllSwiftPubKey(msg);

if (!m_remote_ephemeral_key.IsFullyValid()) {
throw std::runtime_error("Sv2HandshakeState::ReadMsgEphemeralPK(): Received invalid remote ephemeral key");
}
m_symmetric_state.MixHash(Span(&msg[0], KEY_SIZE));
m_symmetric_state.MixHash(Span(&msg[0], ELLSWIFT_KEY_SIZE));
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix hash: %s\n", HexStr(m_symmetric_state.m_hash_output));

std::vector<std::byte> empty;
Expand All @@ -236,50 +237,43 @@ void Sv2HandshakeState::WriteMsgES(Span<std::byte> msg)
{
ssize_t bytes_written = 0;

Assume(m_remote_ephemeral_key.IsFullyValid());

GenerateEphemeralKey(m_ephemeral_key);
GenerateEphemeralKey();

// Send our ephemeral pk.
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Write our ephemeral key\n");
auto ephemeral_pk = XOnlyPubKey(m_ephemeral_key.GetPubKey());
Assume(ephemeral_pk.IsFullyValid());
std::transform(ephemeral_pk.begin(), ephemeral_pk.end(), msg.begin(),
[](unsigned char b) { return static_cast<std::byte>(b); });
std::copy(m_our_ephemeral_ellswift_pk.begin(), m_our_ephemeral_ellswift_pk.end(), msg.begin());

m_symmetric_state.MixHash(Span(msg.begin(), KEY_SIZE));
bytes_written += KEY_SIZE;
m_symmetric_state.MixHash(Span(msg.begin(), ELLSWIFT_KEY_SIZE));
bytes_written += ELLSWIFT_KEY_SIZE;

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix hash: %s\n", HexStr(m_symmetric_state.m_hash_output));

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Perform ECDH with the remote ephemeral key\n");
uint8_t ecdh_output[ECDH_OUTPUT_SIZE] = {};
if (!m_ephemeral_key.ECDH(m_remote_ephemeral_key, ecdh_output)) {
throw std::runtime_error("Failed to perform ECDH on the remote ephemeral key using our ephemeral key");
}
ECDHSecret ecdh_secret{m_ephemeral_key.ComputeBIP324ECDHSecret(m_remote_ephemeral_ellswift_pk,
m_our_ephemeral_ellswift_pk,
/*initiating=*/false)};

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix key with ECDH result: ephemeral ours -- remote ephemeral\n");
m_symmetric_state.MixKey(Span(ecdh_output));
m_symmetric_state.MixKey(ecdh_secret);
m_symmetric_state.LogChainingKey();

// Send our static pk.
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Encrypt and write our static key\n");
auto static_pk = XOnlyPubKey(m_static_key.GetPubKey());
Assume(static_pk.IsFullyValid());
std::transform(static_pk.begin(), static_pk.end(), msg.begin() + KEY_SIZE,
[](unsigned char b) { return static_cast<std::byte>(b); });
m_symmetric_state.EncryptAndHash(Span(msg.begin() + KEY_SIZE, KEY_SIZE + POLY1305_TAGLEN));
bytes_written += KEY_SIZE + POLY1305_TAGLEN;
std::copy(m_our_static_ellswift_pk.begin(), m_our_static_ellswift_pk.end(), msg.begin() + ELLSWIFT_KEY_SIZE);

m_symmetric_state.EncryptAndHash(Span(msg.begin() + ELLSWIFT_KEY_SIZE, ELLSWIFT_KEY_SIZE + POLY1305_TAGLEN));
bytes_written += ELLSWIFT_KEY_SIZE + POLY1305_TAGLEN;

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix hash: %s\n", HexStr(m_symmetric_state.m_hash_output));

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Perform ECDH between our static and remote ephemeral key\n");
uint8_t ecdh_output_remote[ECDH_OUTPUT_SIZE];
if (!m_static_key.ECDH(m_remote_ephemeral_key, ecdh_output_remote)) {
throw std::runtime_error("Failed to perform ECDH on the remote ephemeral key using our static key");
}
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "ECDH result: %s\n", HexStr(ecdh_output_remote));
ECDHSecret ecdh_static_secret{m_static_key.ComputeBIP324ECDHSecret(m_remote_ephemeral_ellswift_pk,
m_our_static_ellswift_pk,
/*initiating=*/false)};
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "ECDH result: %s\n", HexStr(ecdh_static_secret));

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix key with ECDH result: static ours -- remote ephemeral\n");
m_symmetric_state.MixKey(Span(ecdh_output_remote));
m_symmetric_state.MixKey(ecdh_static_secret);
m_symmetric_state.LogChainingKey();

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Generate certificate\n");
Expand Down Expand Up @@ -317,48 +311,41 @@ bool Sv2HandshakeState::ReadMsgES(Span<std::byte> msg)

// Read the remote ephmeral key from the msg and decrypt.
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Read remote ephemeral key\n");
auto remote_ephemeral_key_span = UCharSpanCast(Span(msg.begin(), KEY_SIZE));
m_remote_ephemeral_key = XOnlyPubKey(remote_ephemeral_key_span);
if (!m_remote_ephemeral_key.IsFullyValid()) {
throw std::runtime_error("Sv2HandshakeState::ReadMsgES(): Received invalid remote ephemeral key");
}
bytes_read += KEY_SIZE;
auto remote_ephemeral_key_span = Span(msg.begin(), ELLSWIFT_KEY_SIZE);
m_remote_ephemeral_ellswift_pk = EllSwiftPubKey(remote_ephemeral_key_span);
bytes_read += ELLSWIFT_KEY_SIZE;

m_symmetric_state.MixHash(Span(msg.begin(), KEY_SIZE));
m_symmetric_state.MixHash(Span(msg.begin(), ELLSWIFT_KEY_SIZE));
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix hash: %s\n", HexStr(m_symmetric_state.m_hash_output));

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Perform ECDH with the remote ephemeral key\n");
uint8_t ecdh_output[ECDH_OUTPUT_SIZE];
m_ephemeral_key.ECDH(m_remote_ephemeral_key, ecdh_output);
ECDHSecret ecdh_secret{m_ephemeral_key.ComputeBIP324ECDHSecret(m_remote_ephemeral_ellswift_pk,
m_our_ephemeral_ellswift_pk,
/*initiating=*/true)};

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix key with ECDH result: ephemeral ours -- remote ephemeral\n");
m_symmetric_state.MixKey(Span(ecdh_output));
m_symmetric_state.MixKey(ecdh_secret);
m_symmetric_state.LogChainingKey();

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Decrypt remote static key\n");
bool res = m_symmetric_state.DecryptAndHash(Span(msg.begin() + KEY_SIZE, KEY_SIZE + POLY1305_TAGLEN));
bool res = m_symmetric_state.DecryptAndHash(Span(msg.begin() + ELLSWIFT_KEY_SIZE, ELLSWIFT_KEY_SIZE + POLY1305_TAGLEN));
if (!res) return false;
bytes_read += KEY_SIZE + POLY1305_TAGLEN;
bytes_read += ELLSWIFT_KEY_SIZE + POLY1305_TAGLEN;

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix hash: %s\n", HexStr(m_symmetric_state.m_hash_output));

// Load remote static key from the decryted msg
auto remote_static_key_span = UCharSpanCast(Span(msg.begin() + KEY_SIZE, KEY_SIZE));
m_remote_static_key = XOnlyPubKey(remote_static_key_span);

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Check if remote static key is valid\n");
if (!m_remote_static_key.IsFullyValid()) {
throw std::runtime_error("Sv2HandshakeState::ReadMsgES(): Received invalid remote static key");
}
auto remote_static_key_span = Span(msg.begin() + ELLSWIFT_KEY_SIZE, ELLSWIFT_KEY_SIZE);
m_remote_static_ellswift_pk = EllSwiftPubKey(remote_static_key_span);

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Perform ECDH on the remote static key\n");
uint8_t ecdh_output_remote[ECDH_OUTPUT_SIZE];
if (!m_ephemeral_key.ECDH(m_remote_static_key, ecdh_output_remote)) {
throw std::runtime_error("Failed to perform ECDH on the remote static key using our ephemeral key");
}
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "ECDH result: %s\n", HexStr(ecdh_output_remote));
ECDHSecret ecdh_static_secret{m_ephemeral_key.ComputeBIP324ECDHSecret(m_remote_static_ellswift_pk,
m_our_ephemeral_ellswift_pk,
/*initiating=*/true)};
LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "ECDH result: %s\n", HexStr(ecdh_static_secret));

LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Mix key with ECDH result: ephemeral ours -- remote static\n");
m_symmetric_state.MixKey(Span(ecdh_output_remote));
m_symmetric_state.MixKey(ecdh_static_secret);
m_symmetric_state.LogChainingKey();


Expand Down
30 changes: 17 additions & 13 deletions src/common/sv2_noise.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,24 @@

static constexpr size_t POLY1305_TAGLEN{16};
static constexpr size_t KEY_SIZE = 32;
static constexpr size_t ELLSWIFT_KEY_SIZE = 64;
static constexpr size_t ECDH_OUTPUT_SIZE = 32;
/** Section 3: All Noise messages are less than or equal to 65535 bytes in length. */
static constexpr size_t NOISE_MAX_CHUNK_SIZE = 65535;
/** Sv2 spec 4.5.2 */
static constexpr size_t SIGNATURE_NOISE_MESSAGE_SIZE = 2 + 4 + 4 + 64;
static constexpr size_t INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_LENGTH = KEY_SIZE + KEY_SIZE +
static constexpr size_t INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_LENGTH = ELLSWIFT_KEY_SIZE + ELLSWIFT_KEY_SIZE +
POLY1305_TAGLEN + SIGNATURE_NOISE_MESSAGE_SIZE + POLY1305_TAGLEN;

// Sha256 hash of the ascii encoding - "Noise_NX_secp256k1_ChaChaPoly_SHA256".
// Sha256 hash of the ascii encoding - "Noise_NX_EllSwiftXonly_ChaChaPoly_SHA256".
// This is the first step required when setting up the chaining key.
const std::vector<uint8_t> PROTOCOL_NAME_HASH = {
168, 246, 65, 106, 218, 197, 235, 205, 62, 183, 118, 131, 234, 247, 6, 174, 180, 164, 162, 125,
30, 121, 156, 182, 95, 117, 218, 138, 122, 135, 4, 65,
};
27, 97, 156, 90, 248, 120, 254, 68, 34, 119, 45, 129, 209, 41, 152, 82,
26,137, 97, 115, 62, 44, 177, 60, 145, 24, 250, 214, 68, 188, 1, 128};

// The double hash of protocol name "Noise_NX_secp256k1_ChaChaPoly_SHA256".
static std::vector<uint8_t> PROTOCOL_NAME_DOUBLE_HASH = {132, 175, 109, 74, 47, 106, 167, 237, 124, 169, 128, 188, 123, 69, 19, 92, 215, 4, 100, 205, 0, 191, 211, 210, 38, 190, 247, 183, 20, 200, 116, 58};
// The double hash of protocol name "Noise_NX_EllSwiftXonly_ChaChaPoly_SHA256".
static std::vector<uint8_t> PROTOCOL_NAME_DOUBLE_HASH = {60, 102, 112, 143, 69, 248, 185, 34, 53, 193, 3, 46, 250, 104, 70, 171,
139, 103, 55, 191, 199, 9, 77, 179, 99, 170, 7, 240, 219, 36, 226, 71};

class Sv2SignatureNoiseMessage
{
Expand Down Expand Up @@ -108,7 +109,7 @@ class Sv2SymmetricState
}

void MixHash(const Span<const std::byte> input);
void MixKey(const Span<const uint8_t> input_key_material);
void MixKey(const Span<const std::byte> input_key_material);
void EncryptAndHash(Span<std::byte> data);
[[ nodiscard ]] bool DecryptAndHash(Span<std::byte> data);
std::array<Sv2CipherState, 2> Split();
Expand All @@ -119,7 +120,7 @@ class Sv2SymmetricState
private:
Sv2CipherState m_cipher_state;

void HKDF2(const Span<const uint8_t> input_key_material, uint8_t out0[KEY_SIZE], uint8_t out1[KEY_SIZE]);
void HKDF2(const Span<const std::byte> input_key_material, uint8_t out0[KEY_SIZE], uint8_t out1[KEY_SIZE]);

};

Expand Down Expand Up @@ -178,13 +179,15 @@ class Sv2HandshakeState
{
public:
CKey m_static_key;
XOnlyPubKey m_remote_ephemeral_key;
XOnlyPubKey m_remote_static_key;
CKey m_ephemeral_key;
EllSwiftPubKey m_our_static_ellswift_pk;
EllSwiftPubKey m_our_ephemeral_ellswift_pk;
EllSwiftPubKey m_remote_ephemeral_ellswift_pk;
EllSwiftPubKey m_remote_static_ellswift_pk;
Sv2SymmetricState m_symmetric_state;

Sv2HandshakeState() = default;
Sv2HandshakeState(CKey&& static_key): m_static_key{static_key} {};
Sv2HandshakeState(CKey&& static_key);

void WriteMsgEphemeralPK(Span<std::byte> msg);
void ReadMsgEphemeralPK(Span<std::byte> msg);
Expand All @@ -198,7 +201,8 @@ class Sv2HandshakeState
[[nodiscard]] bool ReadMsgES(Span<std::byte> msg);

private:
void GenerateEphemeralKey(CKey& key) noexcept;
/** Generate ephemeral key, sets set m_ephemeral_key and m_our_ephemeral_ellswift_pk */
void GenerateEphemeralKey() noexcept;
};

class Sv2Cipher
Expand Down
12 changes: 6 additions & 6 deletions src/common/sv2_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void Sv2Transport::StartSendingHandshake() noexcept
Assume(m_send_state == SendState::HANDSHAKE_STEP_1);
Assume(m_send_buffer.empty());

m_send_buffer.resize(KEY_SIZE);
m_send_buffer.resize(ELLSWIFT_KEY_SIZE);
Assume(m_cipher.m_handshake_state);
m_cipher.m_handshake_state->WriteMsgEphemeralPK(MakeWritableByteSpan(m_send_buffer));

Expand Down Expand Up @@ -227,7 +227,7 @@ bool Sv2Transport::ReceivedBytes(Span<const uint8_t>& msg_bytes) noexcept
if (m_recv_buffer.size() + std::min(msg_bytes.size(), max_read) > m_recv_buffer.capacity()) {
switch (m_recv_state) {
case RecvState::HANDSHAKE_STEP_1:
m_recv_buffer.reserve(KEY_SIZE);
m_recv_buffer.reserve(ELLSWIFT_KEY_SIZE);
break;
case RecvState::HANDSHAKE_STEP_2:
m_recv_buffer.reserve(INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_LENGTH);
Expand Down Expand Up @@ -286,9 +286,9 @@ bool Sv2Transport::ProcessReceivedEphemeralKeyBytes() noexcept
AssertLockHeld(m_recv_mutex);
AssertLockNotHeld(m_send_mutex);
Assume(m_recv_state == RecvState::HANDSHAKE_STEP_1);
Assume(m_recv_buffer.size() <= KEY_SIZE);
Assume(m_recv_buffer.size() <= ELLSWIFT_KEY_SIZE);

if (m_recv_buffer.size() == KEY_SIZE) {
if (m_recv_buffer.size() == ELLSWIFT_KEY_SIZE) {
// Other side's key has been fully received, and can now be Diffie-Hellman
// combined with our key. This is act 1 of the Noise Protocol handshake.
// TODO handle failure
Expand Down Expand Up @@ -342,8 +342,8 @@ size_t Sv2Transport::GetMaxBytesToProcess() noexcept
switch (m_recv_state) {
case RecvState::HANDSHAKE_STEP_1:
// In this state, we only allow the 32-byte key into the receive buffer.
Assume(m_recv_buffer.size() <= KEY_SIZE);
return KEY_SIZE - m_recv_buffer.size();
Assume(m_recv_buffer.size() <= ELLSWIFT_KEY_SIZE);
return ELLSWIFT_KEY_SIZE - m_recv_buffer.size();
case RecvState::HANDSHAKE_STEP_2:
// In this state, we only allow the handshake reply into the receive buffer.
Assume(m_recv_buffer.size() <= INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_LENGTH);
Expand Down
Loading

0 comments on commit a1f476f

Please sign in to comment.