Skip to content

Commit

Permalink
feat: Server name getter for client hello (#4396)
Browse files Browse the repository at this point in the history
  • Loading branch information
maddeleine authored Feb 16, 2024
1 parent 629c60a commit e16e151
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 25 deletions.
20 changes: 20 additions & 0 deletions api/s2n.h
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,26 @@ S2N_API extern int s2n_client_hello_get_legacy_protocol_version(struct s2n_clien
S2N_API extern int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t *groups,
uint16_t groups_count_max, uint16_t *groups_count);

/**
* Gets the length of the first server name in a Client Hello.
*
* @param ch A pointer to the ClientHello
* @param length A pointer which will be populated with the length of the server name
*/
S2N_API extern int s2n_client_hello_get_server_name_length(struct s2n_client_hello *ch, uint16_t *length);

/**
* Gets the first server name in a Client Hello.
*
* Use `s2n_client_hello_get_server_name_length()` to get the amount of memory needed for the buffer.
*
* @param ch A pointer to the ClientHello
* @param server_name A pointer to the memory which will be populated with the server name
* @param length The maximum amount of data that can be written to `server_name`
* @param out_length A pointer which will be populated with the size of the server name
*/
S2N_API extern int s2n_client_hello_get_server_name(struct s2n_client_hello *ch, uint8_t *server_name, uint16_t length, uint16_t *out_length);

/**
* Sets the file descriptor for a s2n connection.
*
Expand Down
21 changes: 21 additions & 0 deletions bindings/rust/s2n-tls/src/client_hello.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,27 @@ impl ClientHello {
Ok(session_id)
}

fn server_name(&self) -> Result<Vec<u8>, Error> {
let mut server_name_length = 0;
unsafe {
s2n_client_hello_get_server_name_length(self.deref_mut_ptr(), &mut server_name_length)
.into_result()?;
}

let mut server_name = vec![0; server_name_length as usize];
let mut out_length = 0;
unsafe {
s2n_client_hello_get_server_name(
self.deref_mut_ptr(),
server_name.as_mut_ptr(),
server_name_length,
&mut out_length,
)
.into_result()?;
}
Ok(server_name)
}

fn raw_message(&self) -> Result<Vec<u8>, Error> {
let message_length =
unsafe { s2n_client_hello_get_raw_message_length(self.deref_mut_ptr()).into_result()? };
Expand Down
70 changes: 66 additions & 4 deletions tests/unit/s2n_client_hello_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "s2n_test.h"
#include "testlib/s2n_sslv2_client_hello.h"
#include "testlib/s2n_testlib.h"
#include "tls/s2n_client_hello.c"
#include "tls/s2n_connection.h"
#include "tls/s2n_handshake.h"
#include "tls/s2n_quic_support.h"
Expand All @@ -46,6 +45,8 @@
#define COMPRESSION_METHODS_LEN 0x05

int s2n_parse_client_hello(struct s2n_connection *conn);
S2N_RESULT s2n_client_hello_get_raw_extension(uint16_t extension_iana,
struct s2n_blob *raw_extensions, struct s2n_blob *extension);

int main(int argc, char **argv)
{
Expand Down Expand Up @@ -1849,7 +1850,7 @@ int main(int argc, char **argv)

EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_legacy_record_version(NULL, &out), S2N_ERR_NULL);
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_legacy_record_version(&client_hello, NULL), S2N_ERR_NULL);
}
};

/* Retrieves record version */
{
Expand All @@ -1859,8 +1860,69 @@ int main(int argc, char **argv)
client_hello.record_version_recorded = 1;
EXPECT_SUCCESS(s2n_client_hello_get_legacy_record_version(&client_hello, &out));
EXPECT_EQUAL(out, S2N_TLS12);
}
}
};
};

/* s2n_client_hello_get_server_name() */
{
/* Safety */
{
struct s2n_client_hello ch = { 0 };
uint16_t length = 0;
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name_length(NULL, &length), S2N_ERR_NULL);
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name_length(&ch, NULL), S2N_ERR_NULL);

uint8_t buffer = 0;
uint16_t out_length = 0;
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name(NULL, &buffer, 0, &out_length), S2N_ERR_NULL);
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name(&ch, NULL, 0, &out_length), S2N_ERR_NULL);
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name(&ch, &buffer, 0, NULL), S2N_ERR_NULL);
};

/* Retrieves the first entry in the server_name extension */
{
DEFER_CLEANUP(struct s2n_connection *client_conn = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
DEFER_CLEANUP(struct s2n_connection *server_conn = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
EXPECT_NOT_NULL(server_conn);
EXPECT_NOT_NULL(client_conn);

DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(),
s2n_config_ptr_free);
EXPECT_NOT_NULL(config);

EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key));
EXPECT_SUCCESS(s2n_connection_set_config(client_conn, config));
EXPECT_SUCCESS(s2n_connection_set_config(server_conn, config));

const char *test_server_name = "test server name!";
EXPECT_SUCCESS(s2n_set_server_name(client_conn, test_server_name));

EXPECT_SUCCESS(s2n_client_hello_send(client_conn));
EXPECT_SUCCESS(s2n_stuffer_copy(&client_conn->handshake.io, &server_conn->handshake.io,
s2n_stuffer_data_available(&client_conn->handshake.io)));
EXPECT_SUCCESS(s2n_client_hello_recv(server_conn));

struct s2n_client_hello *client_hello = s2n_connection_get_client_hello(server_conn);
EXPECT_NOT_NULL(client_hello);

uint16_t length = 0;
EXPECT_SUCCESS(s2n_client_hello_get_server_name_length(client_hello, &length));
EXPECT_EQUAL(strlen(test_server_name), length);
uint8_t buffer[20] = { 0 };
uint16_t out_length = 0;
EXPECT_SUCCESS(s2n_client_hello_get_server_name(client_hello, buffer, sizeof(buffer), &out_length));
EXPECT_EQUAL(length, out_length);

EXPECT_BYTEARRAY_EQUAL(buffer, test_server_name, out_length);

/* Check error occurs if buffer is too small to hold server name */
uint8_t small_buf[2] = { 0 };
out_length = 0;
EXPECT_FAILURE_WITH_ERRNO(s2n_client_hello_get_server_name(client_hello, small_buf, sizeof(small_buf), &out_length), S2N_ERR_SAFETY);
};
};

EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain_and_key));
EXPECT_SUCCESS(s2n_cert_chain_and_key_free(ecdsa_chain_and_key));
Expand Down
40 changes: 20 additions & 20 deletions tls/extensions/s2n_client_server_name.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,28 @@ static int s2n_client_server_name_send(struct s2n_connection *conn, struct s2n_s
return S2N_SUCCESS;
}

/* Read the extension up to the first item in ServerNameList. Store the first entry's length in server_name_len.
* For now s2n ignores all subsequent items in ServerNameList.
/* Read the extension up to the first item in ServerNameList. Instantiates the server_name blob to
* point to the first entry. For now s2n ignores all subsequent items in ServerNameList.
*/
static int s2n_client_server_name_check(struct s2n_connection *conn, struct s2n_stuffer *extension, uint16_t *server_name_len)
S2N_RESULT s2n_client_server_name_parse(struct s2n_stuffer *extension, struct s2n_blob *server_name)
{
POSIX_ENSURE_REF(conn);
uint16_t list_size = 0;
RESULT_GUARD_POSIX(s2n_stuffer_read_uint16(extension, &list_size));
RESULT_ENSURE_LTE(list_size, s2n_stuffer_data_available(extension));

uint16_t size_of_all;
POSIX_GUARD(s2n_stuffer_read_uint16(extension, &size_of_all));
POSIX_ENSURE_LTE(size_of_all, s2n_stuffer_data_available(extension));
uint8_t server_name_type = 0;
RESULT_GUARD_POSIX(s2n_stuffer_read_uint8(extension, &server_name_type));
RESULT_ENSURE_EQ(server_name_type, S2N_NAME_TYPE_HOST_NAME);

uint8_t server_name_type;
POSIX_GUARD(s2n_stuffer_read_uint8(extension, &server_name_type));
POSIX_ENSURE_EQ(server_name_type, S2N_NAME_TYPE_HOST_NAME);
uint16_t length = 0;
RESULT_GUARD_POSIX(s2n_stuffer_read_uint16(extension, &length));
RESULT_ENSURE_LTE(length, s2n_stuffer_data_available(extension));

POSIX_GUARD(s2n_stuffer_read_uint16(extension, server_name_len));
POSIX_ENSURE_LT(*server_name_len, sizeof(conn->server_name));
POSIX_ENSURE_LTE(*server_name_len, s2n_stuffer_data_available(extension));
uint8_t *data = s2n_stuffer_raw_read(extension, length);
RESULT_ENSURE_REF(data);
RESULT_GUARD_POSIX(s2n_blob_init(server_name, data, length));

return S2N_SUCCESS;
return S2N_RESULT_OK;
}

static int s2n_client_server_name_recv(struct s2n_connection *conn, struct s2n_stuffer *extension)
Expand All @@ -89,15 +91,13 @@ static int s2n_client_server_name_recv(struct s2n_connection *conn, struct s2n_s
return S2N_SUCCESS;
}

/* Ignore if malformed. We just won't use the server name. */
uint16_t server_name_len;
if (s2n_client_server_name_check(conn, extension, &server_name_len) != S2N_SUCCESS) {
/* Ignore if malformed or we don't have enough space to store it. We just won't use the server name. */
struct s2n_blob server_name = { 0 };
if (!s2n_result_is_ok(s2n_client_server_name_parse(extension, &server_name)) || server_name.size > S2N_MAX_SERVER_NAME) {
return S2N_SUCCESS;
}

uint8_t *server_name;
POSIX_ENSURE_REF(server_name = s2n_stuffer_raw_read(extension, server_name_len));
POSIX_CHECKED_MEMCPY(conn->server_name, server_name, server_name_len);
POSIX_CHECKED_MEMCPY(conn->server_name, server_name.data, server_name.size);

return S2N_SUCCESS;
}
1 change: 1 addition & 0 deletions tls/extensions/s2n_client_server_name.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
#include "tls/s2n_connection.h"

extern const s2n_extension_type s2n_client_server_name_extension;
S2N_RESULT s2n_client_server_name_parse(struct s2n_stuffer *extension, struct s2n_blob *server_name);
47 changes: 46 additions & 1 deletion tls/s2n_client_hello.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "crypto/s2n_rsa_signing.h"
#include "error/s2n_errno.h"
#include "stuffer/s2n_stuffer.h"
#include "tls/extensions/s2n_client_server_name.h"
#include "tls/extensions/s2n_client_supported_groups.h"
#include "tls/extensions/s2n_extension_list.h"
#include "tls/extensions/s2n_server_key_share.h"
Expand Down Expand Up @@ -962,7 +963,7 @@ int s2n_client_hello_get_legacy_record_version(struct s2n_client_hello *ch, uint
return S2N_SUCCESS;
}

static S2N_RESULT s2n_client_hello_get_raw_extension(uint16_t extension_iana,
S2N_RESULT s2n_client_hello_get_raw_extension(uint16_t extension_iana,
struct s2n_blob *raw_extensions, struct s2n_blob *extension)
{
RESULT_ENSURE_REF(raw_extensions);
Expand Down Expand Up @@ -1046,3 +1047,47 @@ int s2n_client_hello_get_supported_groups(struct s2n_client_hello *ch, uint16_t

return S2N_SUCCESS;
}

int s2n_client_hello_get_server_name_length(struct s2n_client_hello *ch, uint16_t *length)
{
POSIX_ENSURE_REF(ch);
POSIX_ENSURE_REF(length);
*length = 0;

s2n_parsed_extension *server_name_extension = NULL;
POSIX_GUARD(s2n_client_hello_get_parsed_extension(S2N_EXTENSION_SERVER_NAME, &ch->extensions, &server_name_extension));
POSIX_ENSURE_REF(server_name_extension);

struct s2n_stuffer extension_stuffer = { 0 };
POSIX_GUARD(s2n_stuffer_init_written(&extension_stuffer, &server_name_extension->extension));

struct s2n_blob blob = { 0 };
POSIX_GUARD_RESULT(s2n_client_server_name_parse(&extension_stuffer, &blob));
*length = blob.size;

return S2N_SUCCESS;
}

int s2n_client_hello_get_server_name(struct s2n_client_hello *ch, uint8_t *server_name, uint16_t length, uint16_t *out_length)
{
POSIX_ENSURE_REF(out_length);
POSIX_ENSURE_REF(ch);
POSIX_ENSURE_REF(server_name);
*out_length = 0;

s2n_parsed_extension *server_name_extension = NULL;
POSIX_GUARD(s2n_client_hello_get_parsed_extension(S2N_EXTENSION_SERVER_NAME, &ch->extensions, &server_name_extension));
POSIX_ENSURE_REF(server_name_extension);

struct s2n_stuffer extension_stuffer = { 0 };
POSIX_GUARD(s2n_stuffer_init_written(&extension_stuffer, &server_name_extension->extension));

struct s2n_blob blob = { 0 };
POSIX_GUARD_RESULT(s2n_client_server_name_parse(&extension_stuffer, &blob));
POSIX_ENSURE_LTE(blob.size, length);
POSIX_CHECKED_MEMCPY(server_name, blob.data, blob.size);

*out_length = blob.size;

return S2N_SUCCESS;
}

0 comments on commit e16e151

Please sign in to comment.