diff --git a/ext/openssl/ossl_bio.c b/ext/openssl/ossl_bio.c index 875465adc..14b3b83a8 100644 --- a/ext/openssl/ossl_bio.c +++ b/ext/openssl/ossl_bio.c @@ -46,7 +46,7 @@ static VALUE nonblock_kwargs, sym_wait_readable, sym_wait_writable; struct ossl_bio_ctx { VALUE io; - int state; + int *state; int eof; }; @@ -103,16 +103,14 @@ ossl_bio_get(VALUE obj) return bio; } -int -ossl_bio_state(VALUE obj) +void +ossl_bio_set_tag_ptr(VALUE obj, int *state) { BIO *bio; TypedData_Get_Struct(obj, BIO, &ossl_bio_type, bio); struct ossl_bio_ctx *ctx = BIO_get_data(bio); - int state = ctx->state; - ctx->state = 0; - return state; + ctx->state = state; } static int @@ -182,12 +180,9 @@ bio_bwrite(BIO *bio, const char *data, int dlen) struct bwrite_args args = { bio, ctx, data, dlen, 0 }; int state; - if (ctx->state) - return -1; - VALUE ok = rb_protect(bio_bwrite0, (VALUE)&args, &state); if (state) { - ctx->state = state; + *ctx->state = state; return -1; } if (RTEST(ok)) @@ -247,12 +242,9 @@ bio_bread(BIO *bio, char *data, int dlen) struct bread_args args = { bio, ctx, data, dlen, 0 }; int state; - if (ctx->state) - return -1; - VALUE ok = rb_protect(bio_bread0, (VALUE)&args, &state); if (state) { - ctx->state = state; + *ctx->state = state; return -1; } if (RTEST(ok)) @@ -273,15 +265,13 @@ bio_ctrl(BIO *bio, int cmd, long larg, void *parg) struct ossl_bio_ctx *ctx = BIO_get_data(bio); int state; - if (ctx->state) - return 0; - switch (cmd) { case BIO_CTRL_EOF: return ctx->eof; case BIO_CTRL_FLUSH: rb_protect(bio_flush0, (VALUE)ctx, &state); - ctx->state = state; + if (state) + *ctx->state = state; return !state; default: return 0; diff --git a/ext/openssl/ossl_bio.h b/ext/openssl/ossl_bio.h index 634f99fae..10f2e2c0a 100644 --- a/ext/openssl/ossl_bio.h +++ b/ext/openssl/ossl_bio.h @@ -15,7 +15,7 @@ VALUE ossl_membio2str(BIO*); VALUE ossl_bio_new(VALUE io); BIO *ossl_bio_get(VALUE obj); -int ossl_bio_state(VALUE obj); +void ossl_bio_set_tag_ptr(VALUE obj, int *state); void Init_ossl_bio(void); diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index e0f3e5061..7a7fc35a4 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -10,6 +10,8 @@ * (See the file 'COPYING'.) */ #include "ossl.h" +static void +ssl_set_jump_tag(SSL *ssl, int state); #ifndef OPENSSL_NO_SOCK #define numberof(ary) (int)(sizeof(ary)/sizeof((ary)[0])) @@ -36,7 +38,7 @@ VALUE cSSLSocket; static VALUE eSSLErrorWaitReadable; static VALUE eSSLErrorWaitWritable; -static ID id_call, ID_callback_state, id_tmp_dh_callback, +static ID id_call, id_tmp_dh_callback, id_npn_protocols_encoded, id_each, id_bio; static VALUE sym_exception, sym_wait_readable, sym_wait_writable; @@ -50,7 +52,7 @@ static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode, id_i_verify_hostname, id_i_keylog_cb; static ID id_i_io, id_i_context, id_i_hostname; -static int ossl_ssl_ex_ptr_idx; +static int ossl_ssl_ex_ptr_idx, ossl_ssl_jump_tag_idx; static int ossl_sslctx_ex_ptr_idx; static void @@ -282,7 +284,7 @@ ossl_tmp_dh_callback(SSL *ssl, int is_export, int keylength) pkey = (EVP_PKEY *)rb_protect(ossl_call_tmp_dh_callback, (VALUE)&args, &state); if (state) { - rb_ivar_set(rb_ssl, ID_callback_state, INT2NUM(state)); + ssl_set_jump_tag(ssl, state); return NULL; } if (!pkey) @@ -330,7 +332,7 @@ ossl_ssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) !X509_STORE_CTX_get_error_depth(ctx)) { ret = rb_protect(call_verify_certificate_identity, (VALUE)ctx, &status); if (status) { - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(status)); + ssl_set_jump_tag(ssl, status); return 0; } if (ret != Qtrue) { @@ -379,7 +381,7 @@ ossl_sslctx_session_get_cb(SSL *ssl, unsigned char *buf, int len, int *copy) ret_obj = rb_protect(ossl_call_session_get_cb, ary, &state); if (state) { - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state)); + ssl_set_jump_tag(ssl, state); return NULL; } if (!rb_obj_is_instance_of(ret_obj, cSSLSession)) @@ -424,9 +426,8 @@ ossl_sslctx_session_new_cb(SSL *ssl, SSL_SESSION *sess) rb_ary_push(ary, sess_obj); rb_protect(ossl_call_session_new_cb, ary, &state); - if (state) { - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state)); - } + if (state) + ssl_set_jump_tag(ssl, state); /* * return 0 which means to OpenSSL that the session is still @@ -480,9 +481,8 @@ ossl_sslctx_keylog_cb(const SSL *ssl, const char *line) args.line = line; rb_protect(ossl_call_keylog_cb, (VALUE)&args, &state); - if (state) { - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state)); - } + if (state) + ssl_set_jump_tag((SSL *)ssl, state); } #endif @@ -525,13 +525,13 @@ ossl_sslctx_session_remove_cb(SSL_CTX *ctx, SSL_SESSION *sess) rb_ary_push(ary, sess_obj); rb_protect(ossl_call_session_remove_cb, ary, &state); - if (state) { + if (state) + rb_set_errinfo(Qnil); /* the SSL_CTX is frozen, nowhere to save state. there is no common accessor method to check it either. rb_ivar_set(sslctx_obj, ID_callback_state, INT2NUM(state)); */ - } } static VALUE @@ -587,8 +587,7 @@ ssl_servername_cb(SSL *ssl, int *ad, void *arg) rb_protect(ossl_call_servername_cb, (VALUE)ssl, &state); if (state) { - VALUE ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state)); + ssl_set_jump_tag(ssl, state); return SSL_TLSEXT_ERR_ALERT_FATAL; } @@ -678,10 +677,8 @@ ssl_npn_select_cb_common(SSL *ssl, VALUE cb, const unsigned char **out, selected = rb_protect(npn_select_cb_common_i, (VALUE)&args, &status); if (status) { - VALUE ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(status)); - return SSL_TLSEXT_ERR_ALERT_FATAL; + ssl_set_jump_tag(ssl, status); + return SSL_TLSEXT_ERR_ALERT_FATAL; } *out = (unsigned char *)RSTRING_PTR(selected); @@ -1727,39 +1724,50 @@ ossl_ssl_setup(VALUE self) } static void -check_bio_error(SSL *ssl, VALUE bobj, int ret) +ssl_set_jump_tag_ptr(SSL *ssl, VALUE bobj, int *state) +{ + RUBY_ASSERT(SSL_get_ex_data(ssl, ossl_ssl_jump_tag_idx) == NULL); + if (!SSL_set_ex_data(ssl, ossl_ssl_jump_tag_idx, state)) + ossl_raise(eSSLError, "SSL_set_ex_data"); + if (!NIL_P(bobj)) + ossl_bio_set_tag_ptr(bobj, state); +} + +static void +ssl_clear_jump_tag_ptr(SSL *ssl, VALUE bobj) +{ + if (!SSL_set_ex_data(ssl, ossl_ssl_jump_tag_idx, NULL)) + ossl_raise(eSSLError, "SSL_set_ex_data"); + if (!NIL_P(bobj)) + ossl_bio_set_tag_ptr(bobj, NULL); +} + +static void +ssl_check_jump_tag(SSL *ssl, VALUE bobj, int state, int ret) { - // Socket BIO -> nothing to do + ssl_clear_jump_tag_ptr(ssl, bobj); + if (state) { + if (ret <= 0) { + ossl_clear_error(); + rb_jump_tag(state); + } + rb_set_errinfo(Qnil); + } + if (NIL_P(bobj)) { #ifdef _WIN32 errno = rb_w32_map_errno(WSAGetLastError()); #endif - return; } - - int state = ossl_bio_state(bobj); - if (!state) { + else { errno = 0; - return; } +} - /* - * Operation may succeed while the underlying socket reports an error in - * some cases. For example, when TLS 1.3 server tries to send a - * NewSessionTicket on a closed socket (IOW, when the client disconnects - * right after finishing a handshake). - * - * In OpenSSL 3.4.0, ssl/statem/statem_srvr.c conn_is_closed() ignores - * EPIPE and ECONNRESET. - * - * We can't map the exception to a specific errno - */ - if (rb_obj_is_kind_of(rb_errinfo(), rb_eSystemCallError) && ret > 0) { - rb_set_errinfo(Qnil); - return; - } - ossl_clear_error(); - rb_jump_tag(state); +static void +ssl_set_jump_tag(SSL *ssl, int state) +{ + *(int *)SSL_get_ex_data(ssl, ossl_ssl_jump_tag_idx) = state; } static void @@ -1840,26 +1848,18 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) { SSL *ssl; int ret, ret2; - VALUE cb_state; + int state = 0; int nonblock = opts != Qfalse; - rb_ivar_set(self, ID_callback_state, Qnil); - GetSSL(self, ssl); VALUE io = rb_attr_get(self, id_i_io), bobj = rb_attr_get(self, id_bio); for (;;) { + ssl_set_jump_tag_ptr(ssl, bobj, &state); ret = func(ssl); - check_bio_error(ssl, bobj, ret); - - cb_state = rb_attr_get(self, ID_callback_state); - if (!NIL_P(cb_state)) { - /* must cleanup OpenSSL error stack before re-raising */ - ossl_clear_error(); - rb_jump_tag(NUM2INT(cb_state)); - } + ssl_check_jump_tag(ssl, bobj, state, ret); if (ret > 0) break; @@ -2014,7 +2014,8 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) { SSL *ssl; int ilen; - VALUE len, str, cb_state; + int state = 0; + VALUE len, str; VALUE opts = Qnil; if (nonblock) { @@ -2047,15 +2048,9 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) rb_str_locktmp(str); for (;;) { + ssl_set_jump_tag_ptr(ssl, bobj, &state); int nread = SSL_read(ssl, RSTRING_PTR(str), ilen); - check_bio_error(ssl, bobj, nread); - - cb_state = rb_attr_get(self, ID_callback_state); - if (!NIL_P(cb_state)) { - rb_ivar_set(self, ID_callback_state, Qnil); - ossl_clear_error(); - rb_jump_tag(NUM2INT(cb_state)); - } + ssl_check_jump_tag(ssl, bobj, state, nread); switch (SSL_get_error(ssl, nread)) { case SSL_ERROR_NONE: @@ -2150,7 +2145,7 @@ ossl_ssl_write_internal_safe(VALUE _args) SSL *ssl; int num, nonblock = opts != Qfalse; - VALUE cb_state; + int state = 0; GetSSL(self, ssl); if (!ssl_started(ssl)) @@ -2165,15 +2160,9 @@ ossl_ssl_write_internal_safe(VALUE _args) bobj = rb_attr_get(self, id_bio); for (;;) { + ssl_set_jump_tag_ptr(ssl, bobj, &state); int nwritten = SSL_write(ssl, RSTRING_PTR(str), num); - check_bio_error(ssl, bobj, nwritten); - - cb_state = rb_attr_get(self, ID_callback_state); - if (!NIL_P(cb_state)) { - rb_ivar_set(self, ID_callback_state, Qnil); - ossl_clear_error(); - rb_jump_tag(NUM2INT(cb_state)); - } + ssl_check_jump_tag(ssl, bobj, state, nwritten); switch (SSL_get_error(ssl, nwritten)) { case SSL_ERROR_NONE: @@ -2271,17 +2260,23 @@ ossl_ssl_stop(VALUE self) { SSL *ssl; int ret; + int state = 0; GetSSL(self, ssl); if (!ssl_started(ssl)) return Qnil; + VALUE bobj = rb_attr_get(self, id_bio); + + ssl_set_jump_tag_ptr(ssl, bobj, &state); ret = SSL_shutdown(ssl); + ssl_clear_jump_tag_ptr(ssl, bobj); /* XXX: Suppressing errors from the underlying socket */ - VALUE bobj = rb_attr_get(self, id_bio); - if (!NIL_P(bobj) && ossl_bio_state(bobj)) + if (state) { + ossl_clear_error(); rb_set_errinfo(Qnil); + } if (ret == 1) /* Have already received close_notify */ return Qnil; @@ -2731,11 +2726,13 @@ Init_ossl_ssl(void) #ifndef OPENSSL_NO_SOCK id_call = rb_intern_const("call"); - ID_callback_state = rb_intern_const("callback_state"); ossl_ssl_ex_ptr_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_ptr_idx", 0, 0, 0); if (ossl_ssl_ex_ptr_idx < 0) ossl_raise(rb_eRuntimeError, "SSL_get_ex_new_index"); + ossl_ssl_jump_tag_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_jump_tag_idx", 0, 0, 0); + if (ossl_ssl_jump_tag_idx < 0) + ossl_raise(rb_eRuntimeError, "SSL_get_ex_new_index"); ossl_sslctx_ex_ptr_idx = SSL_CTX_get_ex_new_index(0, (void *)"ossl_sslctx_ex_ptr_idx", 0, 0, 0); if (ossl_sslctx_ex_ptr_idx < 0) ossl_raise(rb_eRuntimeError, "SSL_CTX_get_ex_new_index");