4
4
* The MIT License (MIT)
5
5
*
6
6
* Copyright (c) 2015-2019 Paul Sokolovsky
7
+ * Copyright (c) 2023 Damien P. George
7
8
*
8
9
* Permission is hereby granted, free of charge, to any person obtaining a copy
9
10
* of this software and associated documentation files (the "Software"), to deal
35
36
36
37
#include "ssl.h"
37
38
39
+ #define PROTOCOL_TLS_CLIENT (0)
40
+ #define PROTOCOL_TLS_SERVER (1)
41
+
42
+ // This corresponds to an SSLContext object.
43
+ typedef struct _mp_obj_ssl_context_t {
44
+ mp_obj_base_t base ;
45
+ mp_obj_t key ;
46
+ mp_obj_t cert ;
47
+ } mp_obj_ssl_context_t ;
48
+
49
+ // This corresponds to an SSLSocket object.
38
50
typedef struct _mp_obj_ssl_socket_t {
39
51
mp_obj_base_t base ;
40
52
mp_obj_t sock ;
@@ -53,8 +65,15 @@ struct ssl_args {
53
65
mp_arg_val_t do_handshake ;
54
66
};
55
67
68
+ STATIC const mp_obj_type_t ssl_context_type ;
56
69
STATIC const mp_obj_type_t ssl_socket_type ;
57
70
71
+ STATIC mp_obj_t ssl_socket_make_new (mp_obj_ssl_context_t * ssl_context , mp_obj_t sock ,
72
+ bool server_side , bool do_handshake_on_connect , mp_obj_t server_hostname );
73
+
74
+ /******************************************************************************/
75
+ // Helper functions.
76
+
58
77
// Table of error strings corresponding to SSL_xxx error codes.
59
78
STATIC const char * const ssl_error_tab1 [] = {
60
79
"NOT_OK" ,
@@ -116,8 +135,71 @@ STATIC NORETURN void ssl_raise_error(int err) {
116
135
nlr_raise (mp_obj_exception_make_new (& mp_type_OSError , 2 , 0 , args ));
117
136
}
118
137
138
+ /******************************************************************************/
139
+ // SSLContext type.
140
+
141
+ STATIC mp_obj_t ssl_context_make_new (const mp_obj_type_t * type_in , size_t n_args , size_t n_kw , const mp_obj_t * args ) {
142
+ mp_arg_check_num (n_args , n_kw , 1 , 1 , false);
143
+
144
+ // The "protocol" argument is ignored in this implementation.
145
+
146
+ // Create SSLContext object.
147
+ #if MICROPY_PY_SSL_FINALISER
148
+ mp_obj_ssl_context_t * self = m_new_obj_with_finaliser (mp_obj_ssl_context_t );
149
+ #else
150
+ mp_obj_ssl_context_t * self = m_new_obj (mp_obj_ssl_context_t );
151
+ #endif
152
+ self -> base .type = type_in ;
153
+ self -> key = mp_const_none ;
154
+ self -> cert = mp_const_none ;
155
+
156
+ return MP_OBJ_FROM_PTR (self );
157
+ }
158
+
159
+ STATIC void ssl_context_load_key (mp_obj_ssl_context_t * self , mp_obj_t key_obj , mp_obj_t cert_obj ) {
160
+ self -> key = key_obj ;
161
+ self -> cert = cert_obj ;
162
+ }
163
+
164
+ STATIC mp_obj_t ssl_context_wrap_socket (size_t n_args , const mp_obj_t * pos_args , mp_map_t * kw_args ) {
165
+ enum { ARG_server_side , ARG_do_handshake_on_connect , ARG_server_hostname };
166
+ static const mp_arg_t allowed_args [] = {
167
+ { MP_QSTR_server_side , MP_ARG_KW_ONLY | MP_ARG_BOOL , {.u_bool = false} },
168
+ { MP_QSTR_do_handshake_on_connect , MP_ARG_KW_ONLY | MP_ARG_BOOL , {.u_bool = true} },
169
+ { MP_QSTR_server_hostname , MP_ARG_KW_ONLY | MP_ARG_OBJ , {.u_rom_obj = MP_ROM_NONE } },
170
+ };
171
+
172
+ // Parse arguments.
173
+ mp_obj_ssl_context_t * self = MP_OBJ_TO_PTR (pos_args [0 ]);
174
+ mp_obj_t sock = pos_args [1 ];
175
+ mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
176
+ mp_arg_parse_all (n_args - 2 , pos_args + 2 , kw_args , MP_ARRAY_SIZE (allowed_args ), allowed_args , args );
177
+
178
+ // Create and return the new SSLSocket object.
179
+ return ssl_socket_make_new (self , sock , args [ARG_server_side ].u_bool ,
180
+ args [ARG_do_handshake_on_connect ].u_bool , args [ARG_server_hostname ].u_obj );
181
+ }
182
+ STATIC MP_DEFINE_CONST_FUN_OBJ_KW (ssl_context_wrap_socket_obj , 2 , ssl_context_wrap_socket );
183
+
184
+ STATIC const mp_rom_map_elem_t ssl_context_locals_dict_table [] = {
185
+ { MP_ROM_QSTR (MP_QSTR_wrap_socket ), MP_ROM_PTR (& ssl_context_wrap_socket_obj ) },
186
+ };
187
+ STATIC MP_DEFINE_CONST_DICT (ssl_context_locals_dict , ssl_context_locals_dict_table );
188
+
189
+ STATIC MP_DEFINE_CONST_OBJ_TYPE (
190
+ ssl_context_type ,
191
+ MP_QSTR_SSLContext ,
192
+ MP_TYPE_FLAG_NONE ,
193
+ make_new , ssl_context_make_new ,
194
+ locals_dict , & ssl_context_locals_dict
195
+ );
196
+
197
+ /******************************************************************************/
198
+ // SSLSocket type.
199
+
200
+ STATIC mp_obj_t ssl_socket_make_new (mp_obj_ssl_context_t * ssl_context , mp_obj_t sock ,
201
+ bool server_side , bool do_handshake_on_connect , mp_obj_t server_hostname ) {
119
202
120
- STATIC mp_obj_ssl_socket_t * ssl_socket_new (mp_obj_t sock , struct ssl_args * args ) {
121
203
#if MICROPY_PY_SSL_FINALISER
122
204
mp_obj_ssl_socket_t * o = m_new_obj_with_finaliser (mp_obj_ssl_socket_t );
123
205
#else
@@ -130,43 +212,43 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args)
130
212
o -> blocking = true;
131
213
132
214
uint32_t options = SSL_SERVER_VERIFY_LATER ;
133
- if (!args -> do_handshake . u_bool ) {
215
+ if (!do_handshake_on_connect ) {
134
216
options |= SSL_CONNECT_IN_PARTS ;
135
217
}
136
- if (args -> key . u_obj != mp_const_none ) {
218
+ if (ssl_context -> key != mp_const_none ) {
137
219
options |= SSL_NO_DEFAULT_KEY ;
138
220
}
139
221
if ((o -> ssl_ctx = ssl_ctx_new (options , SSL_DEFAULT_CLNT_SESS )) == NULL ) {
140
222
mp_raise_OSError (MP_EINVAL );
141
223
}
142
224
143
- if (args -> key . u_obj != mp_const_none ) {
225
+ if (ssl_context -> key != mp_const_none ) {
144
226
size_t len ;
145
- const byte * data = (const byte * )mp_obj_str_get_data (args -> key . u_obj , & len );
227
+ const byte * data = (const byte * )mp_obj_str_get_data (ssl_context -> key , & len );
146
228
int res = ssl_obj_memory_load (o -> ssl_ctx , SSL_OBJ_RSA_KEY , data , len , NULL );
147
229
if (res != SSL_OK ) {
148
230
mp_raise_ValueError (MP_ERROR_TEXT ("invalid key" ));
149
231
}
150
232
151
- data = (const byte * )mp_obj_str_get_data (args -> cert . u_obj , & len );
233
+ data = (const byte * )mp_obj_str_get_data (ssl_context -> cert , & len );
152
234
res = ssl_obj_memory_load (o -> ssl_ctx , SSL_OBJ_X509_CERT , data , len , NULL );
153
235
if (res != SSL_OK ) {
154
236
mp_raise_ValueError (MP_ERROR_TEXT ("invalid cert" ));
155
237
}
156
238
}
157
239
158
- if (args -> server_side . u_bool ) {
240
+ if (server_side ) {
159
241
o -> ssl_sock = ssl_server_new (o -> ssl_ctx , (long )sock );
160
242
} else {
161
243
SSL_EXTENSIONS * ext = ssl_ext_new ();
162
244
163
- if (args -> server_hostname . u_obj != mp_const_none ) {
164
- ext -> host_name = (char * )mp_obj_str_get_str (args -> server_hostname . u_obj );
245
+ if (server_hostname != mp_const_none ) {
246
+ ext -> host_name = (char * )mp_obj_str_get_str (server_hostname );
165
247
}
166
248
167
249
o -> ssl_sock = ssl_client_new (o -> ssl_ctx , (long )sock , NULL , 0 , ext );
168
250
169
- if (args -> do_handshake . u_bool ) {
251
+ if (do_handshake_on_connect ) {
170
252
int r = ssl_handshake_status (o -> ssl_sock );
171
253
172
254
if (r != SSL_OK ) {
@@ -178,18 +260,11 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args)
178
260
ssl_raise_error (r );
179
261
}
180
262
}
181
-
182
263
}
183
264
184
265
return o ;
185
266
}
186
267
187
- STATIC void ssl_socket_print (const mp_print_t * print , mp_obj_t self_in , mp_print_kind_t kind ) {
188
- (void )kind ;
189
- mp_obj_ssl_socket_t * self = MP_OBJ_TO_PTR (self_in );
190
- mp_printf (print , "<_SSLSocket %p>" , self -> ssl_sock );
191
- }
192
-
193
268
STATIC mp_uint_t ssl_socket_read (mp_obj_t o_in , void * buf , mp_uint_t size , int * errcode ) {
194
269
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
195
270
@@ -305,7 +380,6 @@ STATIC const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
305
380
{ MP_ROM_QSTR (MP_QSTR___del__ ), MP_ROM_PTR (& mp_stream_close_obj ) },
306
381
#endif
307
382
};
308
-
309
383
STATIC MP_DEFINE_CONST_DICT (ssl_socket_locals_dict , ssl_socket_locals_dict_table );
310
384
311
385
STATIC const mp_stream_p_t ssl_socket_stream_p = {
@@ -316,16 +390,23 @@ STATIC const mp_stream_p_t ssl_socket_stream_p = {
316
390
317
391
STATIC MP_DEFINE_CONST_OBJ_TYPE (
318
392
ssl_socket_type ,
319
- // Save on qstr's, reuse same as for module
320
- MP_QSTR_ssl ,
393
+ MP_QSTR_SSLSocket ,
321
394
MP_TYPE_FLAG_NONE ,
322
- print , ssl_socket_print ,
323
395
protocol , & ssl_socket_stream_p ,
324
396
locals_dict , & ssl_socket_locals_dict
325
397
);
326
398
399
+ /******************************************************************************/
400
+ // ssl module.
401
+
327
402
STATIC mp_obj_t mod_ssl_wrap_socket (size_t n_args , const mp_obj_t * pos_args , mp_map_t * kw_args ) {
328
- // TODO: Implement more args
403
+ enum {
404
+ ARG_key ,
405
+ ARG_cert ,
406
+ ARG_server_side ,
407
+ ARG_server_hostname ,
408
+ ARG_do_handshake ,
409
+ };
329
410
static const mp_arg_t allowed_args [] = {
330
411
{ MP_QSTR_key , MP_ARG_KW_ONLY | MP_ARG_OBJ , {.u_rom_obj = MP_ROM_NONE } },
331
412
{ MP_QSTR_cert , MP_ARG_KW_ONLY | MP_ARG_OBJ , {.u_rom_obj = MP_ROM_NONE } },
@@ -334,22 +415,40 @@ STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_
334
415
{ MP_QSTR_do_handshake , MP_ARG_KW_ONLY | MP_ARG_BOOL , {.u_bool = true} },
335
416
};
336
417
337
- // TODO: Check that sock implements stream protocol
418
+ // Parse arguments.
338
419
mp_obj_t sock = pos_args [0 ];
420
+ mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
421
+ mp_arg_parse_all (n_args - 1 , pos_args + 1 , kw_args , MP_ARRAY_SIZE (allowed_args ), allowed_args , args );
422
+
423
+ // Create SSLContext.
424
+ mp_int_t protocol = args [ARG_server_side ].u_bool ? PROTOCOL_TLS_SERVER : PROTOCOL_TLS_CLIENT ;
425
+ mp_obj_t ssl_context_args [1 ] = { MP_OBJ_NEW_SMALL_INT (protocol ) };
426
+ mp_obj_ssl_context_t * ssl_context = MP_OBJ_TO_PTR (ssl_context_make_new (& ssl_context_type , 1 , 0 , ssl_context_args ));
339
427
340
- struct ssl_args args ;
341
- mp_arg_parse_all (n_args - 1 , pos_args + 1 , kw_args ,
342
- MP_ARRAY_SIZE (allowed_args ), allowed_args , (mp_arg_val_t * )& args );
428
+ // Load key and cert if given.
429
+ if (args [ARG_key ].u_obj != mp_const_none ) {
430
+ ssl_context_load_key (ssl_context , args [ARG_key ].u_obj , args [ARG_cert ].u_obj );
431
+ }
343
432
344
- return MP_OBJ_FROM_PTR (ssl_socket_new (sock , & args ));
433
+ // Create and return the new SSLSocket object.
434
+ return ssl_socket_make_new (ssl_context , sock , args [ARG_server_side ].u_bool ,
435
+ args [ARG_do_handshake ].u_bool , args [ARG_server_hostname ].u_obj );
345
436
}
346
437
STATIC MP_DEFINE_CONST_FUN_OBJ_KW (mod_ssl_wrap_socket_obj , 1 , mod_ssl_wrap_socket );
347
438
348
439
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table [] = {
349
440
{ MP_ROM_QSTR (MP_QSTR___name__ ), MP_ROM_QSTR (MP_QSTR_ssl ) },
441
+
442
+ // Functions.
350
443
{ MP_ROM_QSTR (MP_QSTR_wrap_socket ), MP_ROM_PTR (& mod_ssl_wrap_socket_obj ) },
351
- };
352
444
445
+ // Classes.
446
+ { MP_ROM_QSTR (MP_QSTR_SSLContext ), MP_ROM_PTR (& ssl_context_type ) },
447
+
448
+ // Constants.
449
+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_CLIENT ), MP_ROM_INT (PROTOCOL_TLS_CLIENT ) },
450
+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_SERVER ), MP_ROM_INT (PROTOCOL_TLS_SERVER ) },
451
+ };
353
452
STATIC MP_DEFINE_CONST_DICT (mp_module_ssl_globals , mp_module_ssl_globals_table );
354
453
355
454
const mp_obj_module_t mp_module_ssl = {
0 commit comments