From e16e151b44d3d5549803d776a9ed46d4b7d1cb04 Mon Sep 17 00:00:00 2001 From: maddeleine <59030281+maddeleine@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:45:33 -0800 Subject: [PATCH] feat: Server name getter for client hello (#4396) --- api/s2n.h | 20 +++++++ bindings/rust/s2n-tls/src/client_hello.rs | 21 +++++++ tests/unit/s2n_client_hello_test.c | 70 +++++++++++++++++++++-- tls/extensions/s2n_client_server_name.c | 40 ++++++------- tls/extensions/s2n_client_server_name.h | 1 + tls/s2n_client_hello.c | 47 ++++++++++++++- 6 files changed, 174 insertions(+), 25 deletions(-) diff --git a/api/s2n.h b/api/s2n.h index 5cdb68c8c76..5c375f57491 100644 --- a/api/s2n.h +++ b/api/s2n.h @@ -1625,6 +1625,26 @@ S2N_API extern int s2n_client_hello_get_legacy_protocol_version(struct s2n_clien S2N_API extern int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t *groups, uint16_t groups_count_max, uint16_t *groups_count); +/** + * Gets the length of the first server name in a Client Hello. + * + * @param ch A pointer to the ClientHello + * @param length A pointer which will be populated with the length of the server name + */ +S2N_API extern int s2n_client_hello_get_server_name_length(struct s2n_client_hello *ch, uint16_t *length); + +/** + * Gets the first server name in a Client Hello. + * + * Use `s2n_client_hello_get_server_name_length()` to get the amount of memory needed for the buffer. + * + * @param ch A pointer to the ClientHello + * @param server_name A pointer to the memory which will be populated with the server name + * @param length The maximum amount of data that can be written to `server_name` + * @param out_length A pointer which will be populated with the size of the server name + */ +S2N_API extern int s2n_client_hello_get_server_name(struct s2n_client_hello *ch, uint8_t *server_name, uint16_t length, uint16_t *out_length); + /** * Sets the file descriptor for a s2n connection. * diff --git a/bindings/rust/s2n-tls/src/client_hello.rs b/bindings/rust/s2n-tls/src/client_hello.rs index 4391190d736..dab6e2444db 100644 --- a/bindings/rust/s2n-tls/src/client_hello.rs +++ b/bindings/rust/s2n-tls/src/client_hello.rs @@ -168,6 +168,27 @@ impl ClientHello { Ok(session_id) } + fn server_name(&self) -> Result, Error> { + let mut server_name_length = 0; + unsafe { + s2n_client_hello_get_server_name_length(self.deref_mut_ptr(), &mut server_name_length) + .into_result()?; + } + + let mut server_name = vec![0; server_name_length as usize]; + let mut out_length = 0; + unsafe { + s2n_client_hello_get_server_name( + self.deref_mut_ptr(), + server_name.as_mut_ptr(), + server_name_length, + &mut out_length, + ) + .into_result()?; + } + Ok(server_name) + } + fn raw_message(&self) -> Result, Error> { let message_length = unsafe { s2n_client_hello_get_raw_message_length(self.deref_mut_ptr()).into_result()? }; diff --git a/tests/unit/s2n_client_hello_test.c b/tests/unit/s2n_client_hello_test.c index 1ca4aa16233..89ff7a4d988 100644 --- a/tests/unit/s2n_client_hello_test.c +++ b/tests/unit/s2n_client_hello_test.c @@ -25,7 +25,6 @@ #include "s2n_test.h" #include "testlib/s2n_sslv2_client_hello.h" #include "testlib/s2n_testlib.h" -#include "tls/s2n_client_hello.c" #include "tls/s2n_connection.h" #include "tls/s2n_handshake.h" #include "tls/s2n_quic_support.h" @@ -46,6 +45,8 @@ #define COMPRESSION_METHODS_LEN 0x05 int s2n_parse_client_hello(struct s2n_connection *conn); +S2N_RESULT s2n_client_hello_get_raw_extension(uint16_t extension_iana, + struct s2n_blob *raw_extensions, struct s2n_blob *extension); int main(int argc, char **argv) { @@ -1849,7 +1850,7 @@ int main(int argc, char **argv) EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_legacy_record_version(NULL, &out), S2N_ERR_NULL); EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_legacy_record_version(&client_hello, NULL), S2N_ERR_NULL); - } + }; /* Retrieves record version */ { @@ -1859,8 +1860,69 @@ int main(int argc, char **argv) client_hello.record_version_recorded = 1; EXPECT_SUCCESS(s2n_client_hello_get_legacy_record_version(&client_hello, &out)); EXPECT_EQUAL(out, S2N_TLS12); - } - } + }; + }; + + /* s2n_client_hello_get_server_name() */ + { + /* Safety */ + { + struct s2n_client_hello ch = { 0 }; + uint16_t length = 0; + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name_length(NULL, &length), S2N_ERR_NULL); + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name_length(&ch, NULL), S2N_ERR_NULL); + + uint8_t buffer = 0; + uint16_t out_length = 0; + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name(NULL, &buffer, 0, &out_length), S2N_ERR_NULL); + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name(&ch, NULL, 0, &out_length), S2N_ERR_NULL); + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name(&ch, &buffer, 0, NULL), S2N_ERR_NULL); + }; + + /* Retrieves the first entry in the server_name extension */ + { + DEFER_CLEANUP(struct s2n_connection *client_conn = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + DEFER_CLEANUP(struct s2n_connection *server_conn = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(server_conn); + EXPECT_NOT_NULL(client_conn); + + DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), + s2n_config_ptr_free); + EXPECT_NOT_NULL(config); + + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); + EXPECT_SUCCESS(s2n_connection_set_config(client_conn, config)); + EXPECT_SUCCESS(s2n_connection_set_config(server_conn, config)); + + const char *test_server_name = "test server name!"; + EXPECT_SUCCESS(s2n_set_server_name(client_conn, test_server_name)); + + EXPECT_SUCCESS(s2n_client_hello_send(client_conn)); + EXPECT_SUCCESS(s2n_stuffer_copy(&client_conn->handshake.io, &server_conn->handshake.io, + s2n_stuffer_data_available(&client_conn->handshake.io))); + EXPECT_SUCCESS(s2n_client_hello_recv(server_conn)); + + struct s2n_client_hello *client_hello = s2n_connection_get_client_hello(server_conn); + EXPECT_NOT_NULL(client_hello); + + uint16_t length = 0; + EXPECT_SUCCESS(s2n_client_hello_get_server_name_length(client_hello, &length)); + EXPECT_EQUAL(strlen(test_server_name), length); + uint8_t buffer[20] = { 0 }; + uint16_t out_length = 0; + EXPECT_SUCCESS(s2n_client_hello_get_server_name(client_hello, buffer, sizeof(buffer), &out_length)); + EXPECT_EQUAL(length, out_length); + + EXPECT_BYTEARRAY_EQUAL(buffer, test_server_name, out_length); + + /* Check error occurs if buffer is too small to hold server name */ + uint8_t small_buf[2] = { 0 }; + out_length = 0; + EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name(client_hello, small_buf, sizeof(small_buf), &out_length), S2N_ERR_SAFETY); + }; + }; EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain_and_key)); EXPECT_SUCCESS(s2n_cert_chain_and_key_free(ecdsa_chain_and_key)); diff --git a/tls/extensions/s2n_client_server_name.c b/tls/extensions/s2n_client_server_name.c index f66c0e6512c..7ec68ada3d3 100644 --- a/tls/extensions/s2n_client_server_name.c +++ b/tls/extensions/s2n_client_server_name.c @@ -58,26 +58,28 @@ static int s2n_client_server_name_send(struct s2n_connection *conn, struct s2n_s return S2N_SUCCESS; } -/* Read the extension up to the first item in ServerNameList. Store the first entry's length in server_name_len. - * For now s2n ignores all subsequent items in ServerNameList. +/* Read the extension up to the first item in ServerNameList. Instantiates the server_name blob to + * point to the first entry. For now s2n ignores all subsequent items in ServerNameList. */ -static int s2n_client_server_name_check(struct s2n_connection *conn, struct s2n_stuffer *extension, uint16_t *server_name_len) +S2N_RESULT s2n_client_server_name_parse(struct s2n_stuffer *extension, struct s2n_blob *server_name) { - POSIX_ENSURE_REF(conn); + uint16_t list_size = 0; + RESULT_GUARD_POSIX(s2n_stuffer_read_uint16(extension, &list_size)); + RESULT_ENSURE_LTE(list_size, s2n_stuffer_data_available(extension)); - uint16_t size_of_all; - POSIX_GUARD(s2n_stuffer_read_uint16(extension, &size_of_all)); - POSIX_ENSURE_LTE(size_of_all, s2n_stuffer_data_available(extension)); + uint8_t server_name_type = 0; + RESULT_GUARD_POSIX(s2n_stuffer_read_uint8(extension, &server_name_type)); + RESULT_ENSURE_EQ(server_name_type, S2N_NAME_TYPE_HOST_NAME); - uint8_t server_name_type; - POSIX_GUARD(s2n_stuffer_read_uint8(extension, &server_name_type)); - POSIX_ENSURE_EQ(server_name_type, S2N_NAME_TYPE_HOST_NAME); + uint16_t length = 0; + RESULT_GUARD_POSIX(s2n_stuffer_read_uint16(extension, &length)); + RESULT_ENSURE_LTE(length, s2n_stuffer_data_available(extension)); - POSIX_GUARD(s2n_stuffer_read_uint16(extension, server_name_len)); - POSIX_ENSURE_LT(*server_name_len, sizeof(conn->server_name)); - POSIX_ENSURE_LTE(*server_name_len, s2n_stuffer_data_available(extension)); + uint8_t *data = s2n_stuffer_raw_read(extension, length); + RESULT_ENSURE_REF(data); + RESULT_GUARD_POSIX(s2n_blob_init(server_name, data, length)); - return S2N_SUCCESS; + return S2N_RESULT_OK; } static int s2n_client_server_name_recv(struct s2n_connection *conn, struct s2n_stuffer *extension) @@ -89,15 +91,13 @@ static int s2n_client_server_name_recv(struct s2n_connection *conn, struct s2n_s return S2N_SUCCESS; } - /* Ignore if malformed. We just won't use the server name. */ - uint16_t server_name_len; - if (s2n_client_server_name_check(conn, extension, &server_name_len) != S2N_SUCCESS) { + /* Ignore if malformed or we don't have enough space to store it. We just won't use the server name. */ + struct s2n_blob server_name = { 0 }; + if (!s2n_result_is_ok(s2n_client_server_name_parse(extension, &server_name)) || server_name.size > S2N_MAX_SERVER_NAME) { return S2N_SUCCESS; } - uint8_t *server_name; - POSIX_ENSURE_REF(server_name = s2n_stuffer_raw_read(extension, server_name_len)); - POSIX_CHECKED_MEMCPY(conn->server_name, server_name, server_name_len); + POSIX_CHECKED_MEMCPY(conn->server_name, server_name.data, server_name.size); return S2N_SUCCESS; } diff --git a/tls/extensions/s2n_client_server_name.h b/tls/extensions/s2n_client_server_name.h index 8eb868cb6b0..06aeb14f755 100644 --- a/tls/extensions/s2n_client_server_name.h +++ b/tls/extensions/s2n_client_server_name.h @@ -20,3 +20,4 @@ #include "tls/s2n_connection.h" extern const s2n_extension_type s2n_client_server_name_extension; +S2N_RESULT s2n_client_server_name_parse(struct s2n_stuffer *extension, struct s2n_blob *server_name); diff --git a/tls/s2n_client_hello.c b/tls/s2n_client_hello.c index 1bd72c3c493..76da265855b 100644 --- a/tls/s2n_client_hello.c +++ b/tls/s2n_client_hello.c @@ -26,6 +26,7 @@ #include "crypto/s2n_rsa_signing.h" #include "error/s2n_errno.h" #include "stuffer/s2n_stuffer.h" +#include "tls/extensions/s2n_client_server_name.h" #include "tls/extensions/s2n_client_supported_groups.h" #include "tls/extensions/s2n_extension_list.h" #include "tls/extensions/s2n_server_key_share.h" @@ -962,7 +963,7 @@ int s2n_client_hello_get_legacy_record_version(struct s2n_client_hello *ch, uint return S2N_SUCCESS; } -static S2N_RESULT s2n_client_hello_get_raw_extension(uint16_t extension_iana, +S2N_RESULT s2n_client_hello_get_raw_extension(uint16_t extension_iana, struct s2n_blob *raw_extensions, struct s2n_blob *extension) { RESULT_ENSURE_REF(raw_extensions); @@ -1046,3 +1047,47 @@ int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t return S2N_SUCCESS; } + +int s2n_client_hello_get_server_name_length(struct s2n_client_hello *ch, uint16_t *length) +{ + POSIX_ENSURE_REF(ch); + POSIX_ENSURE_REF(length); + *length = 0; + + s2n_parsed_extension *server_name_extension = NULL; + POSIX_GUARD(s2n_client_hello_get_parsed_extension(S2N_EXTENSION_SERVER_NAME, &ch->extensions, &server_name_extension)); + POSIX_ENSURE_REF(server_name_extension); + + struct s2n_stuffer extension_stuffer = { 0 }; + POSIX_GUARD(s2n_stuffer_init_written(&extension_stuffer, &server_name_extension->extension)); + + struct s2n_blob blob = { 0 }; + POSIX_GUARD_RESULT(s2n_client_server_name_parse(&extension_stuffer, &blob)); + *length = blob.size; + + return S2N_SUCCESS; +} + +int s2n_client_hello_get_server_name(struct s2n_client_hello *ch, uint8_t *server_name, uint16_t length, uint16_t *out_length) +{ + POSIX_ENSURE_REF(out_length); + POSIX_ENSURE_REF(ch); + POSIX_ENSURE_REF(server_name); + *out_length = 0; + + s2n_parsed_extension *server_name_extension = NULL; + POSIX_GUARD(s2n_client_hello_get_parsed_extension(S2N_EXTENSION_SERVER_NAME, &ch->extensions, &server_name_extension)); + POSIX_ENSURE_REF(server_name_extension); + + struct s2n_stuffer extension_stuffer = { 0 }; + POSIX_GUARD(s2n_stuffer_init_written(&extension_stuffer, &server_name_extension->extension)); + + struct s2n_blob blob = { 0 }; + POSIX_GUARD_RESULT(s2n_client_server_name_parse(&extension_stuffer, &blob)); + POSIX_ENSURE_LTE(blob.size, length); + POSIX_CHECKED_MEMCPY(server_name, blob.data, blob.size); + + *out_length = blob.size; + + return S2N_SUCCESS; +}