diff --git a/tests/unit/s2n_psk_offered_test.c b/tests/unit/s2n_psk_offered_test.c index d2a6deadf65..c22b497da9c 100644 --- a/tests/unit/s2n_psk_offered_test.c +++ b/tests/unit/s2n_psk_offered_test.c @@ -330,9 +330,7 @@ int main(int argc, char **argv) uint8_t *data = NULL; uint16_t size = 0; - EXPECT_SUCCESS(s2n_offered_psk_get_identity(psk, &data, &size)); - EXPECT_EQUAL(size, 0); - EXPECT_EQUAL(data, NULL); + EXPECT_FAILURE_WITH_ERRNO(s2n_offered_psk_get_identity(psk, &data, &size), S2N_ERR_NULL); }; /* Valid identity */ diff --git a/tests/unit/s2n_psk_test.c b/tests/unit/s2n_psk_test.c index a6c18936dda..b9ca473987c 100644 --- a/tests/unit/s2n_psk_test.c +++ b/tests/unit/s2n_psk_test.c @@ -94,6 +94,7 @@ int main(int argc, char **argv) EXPECT_OK(s2n_psk_init(&psk, S2N_PSK_TYPE_EXTERNAL)); uint8_t test_value_1[] = TEST_VALUE_1; + uint8_t all_zero_value[5] = { 0 }; EXPECT_FAILURE_WITH_ERRNO(s2n_psk_set_secret(NULL, test_value_1, 1), S2N_ERR_NULL); @@ -101,6 +102,8 @@ int main(int argc, char **argv) S2N_ERR_NULL); EXPECT_FAILURE_WITH_ERRNO(s2n_psk_set_secret(&psk, test_value_1, 0), S2N_ERR_INVALID_ARGUMENT); + EXPECT_FAILURE_WITH_ERRNO(s2n_psk_set_secret(&psk, all_zero_value, s2n_array_len(all_zero_value)), + S2N_ERR_INVALID_ARGUMENT); EXPECT_SUCCESS(s2n_psk_set_secret(&psk, test_value_1, sizeof(test_value_1))); EXPECT_EQUAL(psk.secret.size, sizeof(TEST_VALUE_1)); @@ -794,22 +797,39 @@ int main(int argc, char **argv) /* Invalid PSK not added to connection */ { - struct s2n_connection *conn = NULL; - EXPECT_NOT_NULL(conn = s2n_connection_new(S2N_CLIENT)); - /* PSK is invalid because it has no identity */ - DEFER_CLEANUP(struct s2n_psk *invalid_psk = s2n_external_psk_new(), s2n_psk_free); - EXPECT_SUCCESS(s2n_psk_set_secret(invalid_psk, secret_0, sizeof(secret_0))); + { + DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free); + EXPECT_NOT_NULL(conn); - EXPECT_FAILURE_WITH_ERRNO(s2n_connection_append_psk(conn, invalid_psk), - S2N_ERR_INVALID_ARGUMENT); - EXPECT_EQUAL(conn->psk_params.psk_list.len, 0); + DEFER_CLEANUP(struct s2n_psk *invalid_psk = s2n_external_psk_new(), s2n_psk_free); + EXPECT_SUCCESS(s2n_psk_set_secret(invalid_psk, secret_0, sizeof(secret_0))); - /* Successful if identity added to PSK, making it valid */ - EXPECT_SUCCESS(s2n_psk_set_identity(invalid_psk, identity_0, sizeof(identity_0))); - EXPECT_SUCCESS(s2n_connection_append_psk(conn, invalid_psk)); + EXPECT_FAILURE_WITH_ERRNO(s2n_connection_append_psk(conn, invalid_psk), + S2N_ERR_INVALID_ARGUMENT); + EXPECT_EQUAL(conn->psk_params.psk_list.len, 0); - EXPECT_SUCCESS(s2n_connection_free(conn)); + /* Successful if identity added to PSK, making it valid */ + EXPECT_SUCCESS(s2n_psk_set_identity(invalid_psk, identity_0, sizeof(identity_0))); + EXPECT_SUCCESS(s2n_connection_append_psk(conn, invalid_psk)); + }; + + /* PSK is invalid because it has no secret */ + { + DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free); + EXPECT_NOT_NULL(conn); + + DEFER_CLEANUP(struct s2n_psk *invalid_psk = s2n_external_psk_new(), s2n_psk_free); + EXPECT_SUCCESS(s2n_psk_set_identity(invalid_psk, identity_0, sizeof(identity_0))); + + EXPECT_FAILURE_WITH_ERRNO(s2n_connection_append_psk(conn, invalid_psk), + S2N_ERR_INVALID_ARGUMENT); + EXPECT_EQUAL(conn->psk_params.psk_list.len, 0); + + /* Successful if secret added to PSK, making it valid */ + EXPECT_SUCCESS(s2n_psk_set_secret(invalid_psk, secret_0, sizeof(secret_0))); + EXPECT_SUCCESS(s2n_connection_append_psk(conn, invalid_psk)); + }; }; /* Huge PSK not added to client connection */ diff --git a/tls/s2n_psk.c b/tls/s2n_psk.c index b02ddc94c44..28abf69b07f 100644 --- a/tls/s2n_psk.c +++ b/tls/s2n_psk.c @@ -68,6 +68,15 @@ int s2n_psk_set_secret(struct s2n_psk *psk, const uint8_t *secret, uint16_t secr POSIX_ENSURE_REF(secret); POSIX_ENSURE(secret_size != 0, S2N_ERR_INVALID_ARGUMENT); + /* There are a number of application level errors that might result in an + * all-zero secret accidentally getting used. Error if that happens. + */ + bool secret_is_all_zero = true; + for (uint16_t i = 0; i < secret_size; i++) { + secret_is_all_zero = secret_is_all_zero && secret[i] == 0; + } + POSIX_ENSURE(!secret_is_all_zero, S2N_ERR_INVALID_ARGUMENT); + POSIX_GUARD(s2n_realloc(&psk->secret, secret_size)); POSIX_CHECKED_MEMCPY(psk->secret.data, secret, secret_size); @@ -363,6 +372,7 @@ int s2n_offered_psk_free(struct s2n_offered_psk **psk) int s2n_offered_psk_get_identity(struct s2n_offered_psk *psk, uint8_t **identity, uint16_t *size) { POSIX_ENSURE_REF(psk); + POSIX_ENSURE_REF(psk->identity.data); POSIX_ENSURE_REF(identity); POSIX_ENSURE_REF(size); *identity = psk->identity.data;