diff --git a/include/noise/protocol/handshakestate.h b/include/noise/protocol/handshakestate.h index a49c95a4..9df3ff18 100644 --- a/include/noise/protocol/handshakestate.h +++ b/include/noise/protocol/handshakestate.h @@ -51,6 +51,7 @@ NoiseDHState *noise_handshakestate_get_fixed_ephemeral_dh (NoiseHandshakeState *state); NoiseDHState *noise_handshakestate_get_fixed_hybrid_dh (NoiseHandshakeState *state); +int noise_handshakestate_needs_pre_shared_key(const NoiseHandshakeState *state); int noise_handshakestate_has_pre_shared_key(const NoiseHandshakeState *state); int noise_handshakestate_set_pre_shared_key (NoiseHandshakeState *state, const uint8_t *key, size_t key_len); diff --git a/src/protocol/handshakestate.c b/src/protocol/handshakestate.c index 0c370b83..5649ba24 100644 --- a/src/protocol/handshakestate.c +++ b/src/protocol/handshakestate.c @@ -123,6 +123,16 @@ static int noise_handshakestate_new num_modifiers = 0; while (num_modifiers < NOISE_MAX_MODIFIER_IDS && symmetric->id.modifier_ids[num_modifiers] != 0) { + switch(symmetric->id.modifier_ids[num_modifiers]) { + case NOISE_MODIFIER_PSK0: + case NOISE_MODIFIER_PSK1: + case NOISE_MODIFIER_PSK2: + case NOISE_MODIFIER_PSK3: + extra_reqs |= NOISE_REQ_PSK; + break; + default: + break; + } ++num_modifiers; } if (noise_pattern_expand(tokens, symmetric->id.pattern_id, @@ -533,6 +543,25 @@ NoiseDHState *noise_handshakestate_get_fixed_hybrid_dh return state->dh_fixed_hybrid; } +/** + * \brief Determine if a HandshakeState object requires a pre shared key. + * + * \param state The HandshakeState object. + * + * \return Returns 1 if \a state requires a pre shared key, zero if the + * pre shared key has already been supplied or it is not required. + * + * \sa noise_handshakestate_set_pre_shared_key(), + * noise_handshakestate_has_pre_shared_key() + */ +int noise_handshakestate_needs_pre_shared_key(const NoiseHandshakeState *state) +{ + if (!state || state->pre_shared_key_len) + return 0; + else + return (state->requirements & NOISE_REQ_PSK) != 0; +} + /** * \brief Determine if a HandshakeState object has already been configured * with a pre shared key. @@ -541,7 +570,8 @@ NoiseDHState *noise_handshakestate_get_fixed_hybrid_dh * * \return Returns 1 if \a state has a pre shared key, zero if not. * - * \sa noise_handshakestate_set_pre_shared_key() + * \sa noise_handshakestate_set_pre_shared_key(), + * noise_handshakestate_needs_pre_shared_key() */ int noise_handshakestate_has_pre_shared_key(const NoiseHandshakeState *state) { @@ -569,6 +599,7 @@ int noise_handshakestate_has_pre_shared_key(const NoiseHandshakeState *state) * then the value will be ignored. * * \sa noise_handshakestate_start(), noise_handshakestate_set_prologue(), + * noise_handshakestate_needs_pre_shared_key(), * noise_handshakestate_has_pre_shared_key(), * noise_handshakestate_set_pre_shared_key_hook() */ @@ -1213,6 +1244,9 @@ static int noise_handshakestate_write return NOISE_ERROR_INVALID_LENGTH; memcpy(rest.data, state->dh_local_ephemeral->public_key, len); noise_symmetricstate_mix_hash(state->symmetric, rest.data, len); + if (state->requirements & NOISE_REQ_PSK) { + noise_symmetricstate_mix_key(state->symmetric, rest.data, len); + } rest.size += len; break; case NOISE_TOKEN_S: @@ -1476,6 +1510,12 @@ static int noise_handshakestate_read (state->symmetric, msg.data, len); if (err != NOISE_ERROR_NONE) break; + if (state->requirements & NOISE_REQ_PSK) { + err = noise_symmetricstate_mix_key + (state->symmetric, msg.data, len); + if (err != NOISE_ERROR_NONE) + break; + } err = noise_dhstate_set_public_key (state->dh_remote_ephemeral, msg.data, len); if (err != NOISE_ERROR_NONE) diff --git a/src/protocol/internal.h b/src/protocol/internal.h index cb43ecb2..b96f59ae 100644 --- a/src/protocol/internal.h +++ b/src/protocol/internal.h @@ -652,14 +652,16 @@ struct NoiseHandshakeState_s #define NOISE_REQ_LOCAL_REQUIRED (1 << 0) /** Remote public key is required for the handshake */ #define NOISE_REQ_REMOTE_REQUIRED (1 << 1) +/** Pre-shared key has not been provided yet */ +#define NOISE_REQ_PSK (1 << 2) /** Emphemeral key for fallback pre-message has been provided */ -#define NOISE_REQ_FALLBACK_PREMSG (1 << 2) +#define NOISE_REQ_FALLBACK_PREMSG (1 << 3) /** Local public key is part of the pre-message */ -#define NOISE_REQ_LOCAL_PREMSG (1 << 3) +#define NOISE_REQ_LOCAL_PREMSG (1 << 4) /** Remote public key is part of the pre-message */ -#define NOISE_REQ_REMOTE_PREMSG (1 << 4) +#define NOISE_REQ_REMOTE_PREMSG (1 << 5) /** Fallback is possible from this pattern (two-way, ends in "K") */ -#define NOISE_REQ_FALLBACK_POSSIBLE (1 << 5) +#define NOISE_REQ_FALLBACK_POSSIBLE (1 << 6) void noise_rand_bytes(void *bytes, size_t size); diff --git a/src/protocol/symmetricstate.c b/src/protocol/symmetricstate.c index 0e25c8d7..7124016d 100644 --- a/src/protocol/symmetricstate.c +++ b/src/protocol/symmetricstate.c @@ -281,6 +281,10 @@ int noise_symmetricstate_mix_key (state->hash, state->ck, hash_len, input, size, state->ck, hash_len, temp_k, key_len); + /* Truncate temp_k */ + if (hash_len == 64 && key_len > 32) + key_len = 32; + /* Change the cipher key, or set it for the first time */ noise_cipherstate_init_key(state->cipher, temp_k, key_len); noise_clean(temp_k, sizeof(temp_k)); @@ -360,6 +364,10 @@ int noise_symmetricstate_mix_key_and_hash noise_hashstate_hash_two (state->hash, state->h, hash_len, temp_h, hash_len, state->h, hash_len); + /* Truncate temp_k */ + if (hash_len == 64 && key_len > 32) + key_len = 32; + /* Change the cipher key, or set it for the first time */ noise_cipherstate_init_key(state->cipher, temp_k, key_len); noise_clean(temp_h, sizeof(temp_h));