Skip to content

Commit e8a4c1d

Browse files
committed
extmod/modssl: Add SSLContext class.
This commit adds the SSLContext class to the ssl module, and retains the existing ssl.wrap_socket() function to maintain backwards compatibility. CPython deprecated the ssl.wrap_socket() function since CPython 3.7 and instead one should use ssl.SSLContext().wrap_socket(). This commit makes that possible. For the axtls implementation: - ssl.SSLContext is added, although it doesn't hold much state because axtls requires calling ssl_ctx_new() for each new socket - ssl.SSLContext.wrap_socket() is added - ssl.PROTOCOL_TLS_CLIENT and ssl.PROTOCOL_TLS_SERVER are added For the mbedtls implementation: - ssl.SSLContext is added, and holds most of the mbedtls state - ssl.verify_mode is added (getter and setter) - ssl.SSLContext.wrap_socket() is added - ssl.PROTOCOL_TLS_CLIENT and ssl.PROTOCOL_TLS_SERVER are added The signatures match CPython: - SSLContext(protocol) - SSLContext.wrap_socket(sock, *, server_side=False, do_handshake_on_connect=True, server_hostname=None) The existing ssl.wrap_socket() functions retain their existing signature. Signed-off-by: Damien George <[email protected]>
1 parent c2ea8b2 commit e8a4c1d

File tree

5 files changed

+384
-161
lines changed

5 files changed

+384
-161
lines changed

extmod/modssl_axtls.c

Lines changed: 127 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
* The MIT License (MIT)
55
*
66
* Copyright (c) 2015-2019 Paul Sokolovsky
7+
* Copyright (c) 2023 Damien P. George
78
*
89
* Permission is hereby granted, free of charge, to any person obtaining a copy
910
* of this software and associated documentation files (the "Software"), to deal
@@ -35,6 +36,17 @@
3536

3637
#include "ssl.h"
3738

39+
#define PROTOCOL_TLS_CLIENT (0)
40+
#define PROTOCOL_TLS_SERVER (1)
41+
42+
// This corresponds to an SSLContext object.
43+
typedef struct _mp_obj_ssl_context_t {
44+
mp_obj_base_t base;
45+
mp_obj_t key;
46+
mp_obj_t cert;
47+
} mp_obj_ssl_context_t;
48+
49+
// This corresponds to an SSLSocket object.
3850
typedef struct _mp_obj_ssl_socket_t {
3951
mp_obj_base_t base;
4052
mp_obj_t sock;
@@ -53,8 +65,15 @@ struct ssl_args {
5365
mp_arg_val_t do_handshake;
5466
};
5567

68+
STATIC const mp_obj_type_t ssl_context_type;
5669
STATIC const mp_obj_type_t ssl_socket_type;
5770

71+
STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
72+
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname);
73+
74+
/******************************************************************************/
75+
// Helper functions.
76+
5877
// Table of error strings corresponding to SSL_xxx error codes.
5978
STATIC const char *const ssl_error_tab1[] = {
6079
"NOT_OK",
@@ -116,8 +135,71 @@ STATIC NORETURN void ssl_raise_error(int err) {
116135
nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, args));
117136
}
118137

138+
/******************************************************************************/
139+
// SSLContext type.
140+
141+
STATIC mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) {
142+
mp_arg_check_num(n_args, n_kw, 1, 1, false);
143+
144+
// The "protocol" argument is ignored in this implementation.
145+
146+
// Create SSLContext object.
147+
#if MICROPY_PY_SSL_FINALISER
148+
mp_obj_ssl_context_t *self = m_new_obj_with_finaliser(mp_obj_ssl_context_t);
149+
#else
150+
mp_obj_ssl_context_t *self = m_new_obj(mp_obj_ssl_context_t);
151+
#endif
152+
self->base.type = type_in;
153+
self->key = mp_const_none;
154+
self->cert = mp_const_none;
155+
156+
return MP_OBJ_FROM_PTR(self);
157+
}
158+
159+
STATIC void ssl_context_load_key(mp_obj_ssl_context_t *self, mp_obj_t key_obj, mp_obj_t cert_obj) {
160+
self->key = key_obj;
161+
self->cert = cert_obj;
162+
}
163+
164+
STATIC mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
165+
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname };
166+
static const mp_arg_t allowed_args[] = {
167+
{ MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} },
168+
{ MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
169+
{ MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
170+
};
171+
172+
// Parse arguments.
173+
mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(pos_args[0]);
174+
mp_obj_t sock = pos_args[1];
175+
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
176+
mp_arg_parse_all(n_args - 2, pos_args + 2, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
177+
178+
// Create and return the new SSLSocket object.
179+
return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool,
180+
args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj);
181+
}
182+
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket);
183+
184+
STATIC const mp_rom_map_elem_t ssl_context_locals_dict_table[] = {
185+
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) },
186+
};
187+
STATIC MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table);
188+
189+
STATIC MP_DEFINE_CONST_OBJ_TYPE(
190+
ssl_context_type,
191+
MP_QSTR_SSLContext,
192+
MP_TYPE_FLAG_NONE,
193+
make_new, ssl_context_make_new,
194+
locals_dict, &ssl_context_locals_dict
195+
);
196+
197+
/******************************************************************************/
198+
// SSLSocket type.
199+
200+
STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
201+
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) {
119202

120-
STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args) {
121203
#if MICROPY_PY_SSL_FINALISER
122204
mp_obj_ssl_socket_t *o = m_new_obj_with_finaliser(mp_obj_ssl_socket_t);
123205
#else
@@ -130,43 +212,43 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args)
130212
o->blocking = true;
131213

132214
uint32_t options = SSL_SERVER_VERIFY_LATER;
133-
if (!args->do_handshake.u_bool) {
215+
if (!do_handshake_on_connect) {
134216
options |= SSL_CONNECT_IN_PARTS;
135217
}
136-
if (args->key.u_obj != mp_const_none) {
218+
if (ssl_context->key != mp_const_none) {
137219
options |= SSL_NO_DEFAULT_KEY;
138220
}
139221
if ((o->ssl_ctx = ssl_ctx_new(options, SSL_DEFAULT_CLNT_SESS)) == NULL) {
140222
mp_raise_OSError(MP_EINVAL);
141223
}
142224

143-
if (args->key.u_obj != mp_const_none) {
225+
if (ssl_context->key != mp_const_none) {
144226
size_t len;
145-
const byte *data = (const byte *)mp_obj_str_get_data(args->key.u_obj, &len);
227+
const byte *data = (const byte *)mp_obj_str_get_data(ssl_context->key, &len);
146228
int res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_RSA_KEY, data, len, NULL);
147229
if (res != SSL_OK) {
148230
mp_raise_ValueError(MP_ERROR_TEXT("invalid key"));
149231
}
150232

151-
data = (const byte *)mp_obj_str_get_data(args->cert.u_obj, &len);
233+
data = (const byte *)mp_obj_str_get_data(ssl_context->cert, &len);
152234
res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_X509_CERT, data, len, NULL);
153235
if (res != SSL_OK) {
154236
mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
155237
}
156238
}
157239

158-
if (args->server_side.u_bool) {
240+
if (server_side) {
159241
o->ssl_sock = ssl_server_new(o->ssl_ctx, (long)sock);
160242
} else {
161243
SSL_EXTENSIONS *ext = ssl_ext_new();
162244

163-
if (args->server_hostname.u_obj != mp_const_none) {
164-
ext->host_name = (char *)mp_obj_str_get_str(args->server_hostname.u_obj);
245+
if (server_hostname != mp_const_none) {
246+
ext->host_name = (char *)mp_obj_str_get_str(server_hostname);
165247
}
166248

167249
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);
168250

169-
if (args->do_handshake.u_bool) {
251+
if (do_handshake_on_connect) {
170252
int r = ssl_handshake_status(o->ssl_sock);
171253

172254
if (r != SSL_OK) {
@@ -178,18 +260,11 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args)
178260
ssl_raise_error(r);
179261
}
180262
}
181-
182263
}
183264

184265
return o;
185266
}
186267

187-
STATIC void ssl_socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) {
188-
(void)kind;
189-
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in);
190-
mp_printf(print, "<_SSLSocket %p>", self->ssl_sock);
191-
}
192-
193268
STATIC mp_uint_t ssl_socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
194269
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
195270

@@ -305,7 +380,6 @@ STATIC const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
305380
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) },
306381
#endif
307382
};
308-
309383
STATIC MP_DEFINE_CONST_DICT(ssl_socket_locals_dict, ssl_socket_locals_dict_table);
310384

311385
STATIC const mp_stream_p_t ssl_socket_stream_p = {
@@ -316,16 +390,23 @@ STATIC const mp_stream_p_t ssl_socket_stream_p = {
316390

317391
STATIC MP_DEFINE_CONST_OBJ_TYPE(
318392
ssl_socket_type,
319-
// Save on qstr's, reuse same as for module
320-
MP_QSTR_ssl,
393+
MP_QSTR_SSLSocket,
321394
MP_TYPE_FLAG_NONE,
322-
print, ssl_socket_print,
323395
protocol, &ssl_socket_stream_p,
324396
locals_dict, &ssl_socket_locals_dict
325397
);
326398

399+
/******************************************************************************/
400+
// ssl module.
401+
327402
STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
328-
// TODO: Implement more args
403+
enum {
404+
ARG_key,
405+
ARG_cert,
406+
ARG_server_side,
407+
ARG_server_hostname,
408+
ARG_do_handshake,
409+
};
329410
static const mp_arg_t allowed_args[] = {
330411
{ MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
331412
{ MP_QSTR_cert, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
@@ -334,22 +415,40 @@ STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_
334415
{ MP_QSTR_do_handshake, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
335416
};
336417

337-
// TODO: Check that sock implements stream protocol
418+
// Parse arguments.
338419
mp_obj_t sock = pos_args[0];
420+
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
421+
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
422+
423+
// Create SSLContext.
424+
mp_int_t protocol = args[ARG_server_side].u_bool ? PROTOCOL_TLS_SERVER : PROTOCOL_TLS_CLIENT;
425+
mp_obj_t ssl_context_args[1] = { MP_OBJ_NEW_SMALL_INT(protocol) };
426+
mp_obj_ssl_context_t *ssl_context = MP_OBJ_TO_PTR(ssl_context_make_new(&ssl_context_type, 1, 0, ssl_context_args));
339427

340-
struct ssl_args args;
341-
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args,
342-
MP_ARRAY_SIZE(allowed_args), allowed_args, (mp_arg_val_t *)&args);
428+
// Load key and cert if given.
429+
if (args[ARG_key].u_obj != mp_const_none) {
430+
ssl_context_load_key(ssl_context, args[ARG_key].u_obj, args[ARG_cert].u_obj);
431+
}
343432

344-
return MP_OBJ_FROM_PTR(ssl_socket_new(sock, &args));
433+
// Create and return the new SSLSocket object.
434+
return ssl_socket_make_new(ssl_context, sock, args[ARG_server_side].u_bool,
435+
args[ARG_do_handshake].u_bool, args[ARG_server_hostname].u_obj);
345436
}
346437
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socket);
347438

348439
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
349440
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ssl) },
441+
442+
// Functions.
350443
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
351-
};
352444

445+
// Classes.
446+
{ MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) },
447+
448+
// Constants.
449+
{ MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(PROTOCOL_TLS_CLIENT) },
450+
{ MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(PROTOCOL_TLS_SERVER) },
451+
};
353452
STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
354453

355454
const mp_obj_module_t mp_module_ssl = {

0 commit comments

Comments
 (0)