diff --git a/ext/openssl/openssl_missing.h b/ext/openssl/openssl_missing.h index 0711f924e..02f24078d 100644 --- a/ext/openssl/openssl_missing.h +++ b/ext/openssl/openssl_missing.h @@ -156,6 +156,26 @@ IMPL_PKEY_GETTER(EC_KEY, ec) #undef IMPL_PKEY_GETTER #undef IMPL_KEY_ACCESSOR2 #undef IMPL_KEY_ACCESSOR3 + +// BIO +static inline void *BIO_get_data(BIO *bio) { return bio->ptr; } +static inline void BIO_set_data(BIO *bio, void *data) { bio->ptr = data; } +static inline void BIO_set_init(BIO *bio, int init) { bio->init = init; } +static inline BIO_METHOD *BIO_meth_new(int type, const char *name) { + BIO_METHOD *meth = OPENSSL_malloc(sizeof(*meth)); + if (!meth) + return NULL; + memset(meth, 0, sizeof(*meth)); + meth->type = type; + meth->name = name; + return meth; +} +static inline void BIO_meth_free(BIO_METHOD *meth) { OPENSSL_free(meth); } +static inline int BIO_meth_set_create(BIO_METHOD *meth, int (*f)(BIO *)) { meth->create = f; return 1; } +static inline int BIO_meth_set_destroy(BIO_METHOD *meth, int (*f)(BIO *)) { meth->destroy = f; return 1; } +static inline int BIO_meth_set_write(BIO_METHOD *meth, int (*f)(BIO *, const char *, int)) { meth->bwrite = f; return 1; } +static inline int BIO_meth_set_read(BIO_METHOD *meth, int (*f)(BIO *, char *, int)) { meth->bread = f; return 1; } +static inline int BIO_meth_set_ctrl(BIO_METHOD *meth, long (*f)(BIO *, int, long, void *)) { meth->ctrl = f; return 1; } #endif /* HAVE_OPAQUE_OPENSSL */ #if !defined(EVP_CTRL_AEAD_GET_TAG) diff --git a/ext/openssl/ossl.c b/ext/openssl/ossl.c index 59ad7d19a..74dd44dd8 100644 --- a/ext/openssl/ossl.c +++ b/ext/openssl/ossl.c @@ -1150,6 +1150,7 @@ Init_openssl(void) /* * Init components */ + Init_ossl_bio(); Init_ossl_bn(); Init_ossl_cipher(); Init_ossl_config(); diff --git a/ext/openssl/ossl_bio.c b/ext/openssl/ossl_bio.c index 2ef208050..318ab5cb1 100644 --- a/ext/openssl/ossl_bio.c +++ b/ext/openssl/ossl_bio.c @@ -40,3 +40,230 @@ ossl_membio2str(BIO *bio) return ret; } + +static BIO_METHOD *ossl_bio_meth; +static VALUE nonblock_kwargs, sym_wait_readable, sym_wait_writable; + +struct ossl_bio_ctx { + VALUE io; + int state; + int eof; +}; + +BIO * +ossl_bio_new(VALUE io) +{ + BIO *bio = BIO_new(ossl_bio_meth); + if (!bio) + ossl_raise(eOSSLError, "BIO_new"); + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + ctx->io = io; + BIO_set_init(bio, 1); + return bio; +} + +int +ossl_bio_state(BIO *bio) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + int state = ctx->state; + ctx->state = 0; + return state; +} + +static int +bio_create(BIO *bio) +{ + struct ossl_bio_ctx *ctx = OPENSSL_malloc(sizeof(*ctx)); + if (!ctx) + return 0; + memset(ctx, 0, sizeof(*ctx)); + BIO_set_data(bio, ctx); + + return 1; +} + +static int +bio_destroy(BIO *bio) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + if (ctx) { + OPENSSL_free(ctx); + BIO_set_data(bio, NULL); + } + + return 1; +} + +struct bwrite_args { + BIO *bio; + struct ossl_bio_ctx *ctx; + const char *data; + int dlen; + int written; +}; + +static VALUE +bio_bwrite0(VALUE args) +{ + struct bwrite_args *p = (void *)args; + BIO_clear_retry_flags(p->bio); + + VALUE fargs[] = { rb_str_new_static(p->data, p->dlen), nonblock_kwargs }; + VALUE ret = rb_funcallv_public_kw(p->ctx->io, rb_intern("write_nonblock"), + 2, fargs, RB_PASS_KEYWORDS); + + if (RB_INTEGER_TYPE_P(ret)) { + p->written = NUM2INT(ret); + return Qtrue; + } + else if (ret == sym_wait_readable) { + BIO_set_retry_read(p->bio); + return Qfalse; + } + else if (ret == sym_wait_writable) { + BIO_set_retry_write(p->bio); + return Qfalse; + } + else { + rb_raise(rb_eTypeError, "write_nonblock must return an Integer, " + ":wait_readable, or :wait_writable"); + } +} + +static int +bio_bwrite(BIO *bio, const char *data, int dlen) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + struct bwrite_args args = { bio, ctx, data, dlen, 0 }; + int state; + + if (ctx->state) + return -1; + + VALUE ok = rb_protect(bio_bwrite0, (VALUE)&args, &state); + if (state) { + ctx->state = state; + return -1; + } + if (RTEST(ok)) + return args.written; + return -1; +} + +struct bread_args { + BIO *bio; + struct ossl_bio_ctx *ctx; + char *data; + int dlen; + int readbytes; +}; + +static VALUE +bio_bread0(VALUE args) +{ + struct bread_args *p = (void *)args; + BIO_clear_retry_flags(p->bio); + + VALUE fargs[] = { INT2NUM(p->dlen), nonblock_kwargs }; + VALUE ret = rb_funcallv_public_kw(p->ctx->io, rb_intern("read_nonblock"), + 2, fargs, RB_PASS_KEYWORDS); + + if (RB_TYPE_P(ret, T_STRING)) { + int len = RSTRING_LENINT(ret); + if (len > p->dlen) + rb_raise(rb_eTypeError, "read_nonblock returned too much data"); + memcpy(p->data, RSTRING_PTR(ret), len); + p->readbytes = len; + return Qtrue; + } + else if (NIL_P(ret)) { + // In OpenSSL 3.0 or later: BIO_set_flags(p->bio, BIO_FLAGS_IN_EOF); + p->ctx->eof = 1; + return Qtrue; + } + else if (ret == sym_wait_readable) { + BIO_set_retry_read(p->bio); + return Qfalse; + } + else if (ret == sym_wait_writable) { + BIO_set_retry_write(p->bio); + return Qfalse; + } + else { + rb_raise(rb_eTypeError, "write_nonblock must return an Integer, " + ":wait_readable, or :wait_writable"); + } +} + +static int +bio_bread(BIO *bio, char *data, int dlen) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + struct bread_args args = { bio, ctx, data, dlen, 0 }; + int state; + + if (ctx->state) + return -1; + + VALUE ok = rb_protect(bio_bread0, (VALUE)&args, &state); + if (state) { + ctx->state = state; + return -1; + } + if (RTEST(ok)) + return args.readbytes; + return -1; +} + +static VALUE +bio_flush0(VALUE vctx) +{ + struct ossl_bio_ctx *ctx = (void *)vctx; + return rb_funcallv_public(ctx->io, rb_intern("flush"), 0, NULL); +} + +static long +bio_ctrl(BIO *bio, int cmd, long larg, void *parg) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + int state; + + if (ctx->state) + return 0; + + switch (cmd) { + case BIO_CTRL_EOF: + return ctx->eof; + case BIO_CTRL_FLUSH: + rb_protect(bio_flush0, (VALUE)ctx, &state); + ctx->state = state; + return !state; + default: + return 0; + } +} + +void +Init_ossl_bio(void) +{ + ossl_bio_meth = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "Ruby IO-like object"); + if (!ossl_bio_meth) + ossl_raise(eOSSLError, "BIO_meth_new"); + if (!BIO_meth_set_create(ossl_bio_meth, bio_create) || + !BIO_meth_set_destroy(ossl_bio_meth, bio_destroy) || + !BIO_meth_set_write(ossl_bio_meth, bio_bwrite) || + !BIO_meth_set_read(ossl_bio_meth, bio_bread) || + !BIO_meth_set_ctrl(ossl_bio_meth, bio_ctrl)) { + BIO_meth_free(ossl_bio_meth); + ossl_bio_meth = NULL; + ossl_raise(eOSSLError, "BIO_meth_set_*"); + } + + nonblock_kwargs = rb_hash_new(); + rb_hash_aset(nonblock_kwargs, ID2SYM(rb_intern_const("exception")), Qfalse); + rb_global_variable(&nonblock_kwargs); + + sym_wait_readable = ID2SYM(rb_intern_const("wait_readable")); + sym_wait_writable = ID2SYM(rb_intern_const("wait_writable")); +} diff --git a/ext/openssl/ossl_bio.h b/ext/openssl/ossl_bio.h index 1b871f1cd..71c80ccc7 100644 --- a/ext/openssl/ossl_bio.h +++ b/ext/openssl/ossl_bio.h @@ -13,4 +13,9 @@ BIO *ossl_obj2bio(volatile VALUE *); VALUE ossl_membio2str(BIO*); +BIO *ossl_bio_new(VALUE io); +int ossl_bio_state(BIO *bio); + +void Init_ossl_bio(void); + #endif diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index 457630ddc..6de1cd790 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -55,7 +55,6 @@ static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode, id_i_verify_hostname, id_i_keylog_cb; static ID id_i_io, id_i_context, id_i_hostname; -static int ossl_ssl_ex_vcb_idx; static int ossl_ssl_ex_ptr_idx; static int ossl_sslctx_ex_ptr_idx; @@ -327,9 +326,9 @@ ossl_ssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) int status; ssl = X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()); - cb = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_vcb_idx); ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); sslctx_obj = rb_attr_get(ssl_obj, id_i_context); + cb = rb_attr_get(sslctx_obj, id_i_verify_callback); verify_hostname = rb_attr_get(sslctx_obj, id_i_verify_hostname); if (preverify_ok && RTEST(verify_hostname) && !SSL_is_server(ssl) && @@ -1552,12 +1551,11 @@ static void ossl_ssl_mark(void *ptr) { SSL *ssl = ptr; - rb_gc_mark((VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx)); + VALUE obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - // Note: this reference is stored as @verify_callback so we don't need to mark it. - // However we do need to ensure GC compaction won't move it, hence why - // we call rb_gc_mark here. - rb_gc_mark((VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_vcb_idx)); + // Ensure GC compaction won't move objects referenced by OpenSSL objects + rb_gc_mark(obj); + rb_gc_mark(rb_attr_get(obj, id_i_io)); } static void @@ -1583,7 +1581,11 @@ ossl_ssl_s_alloc(VALUE klass) static VALUE peer_ip_address(VALUE self) { - VALUE remote_address = rb_funcall(rb_attr_get(self, id_i_io), rb_intern("remote_address"), 0); + VALUE io = rb_attr_get(self, id_i_io); + if (!rb_respond_to(io, rb_intern("remote_address"))) + return rb_str_new_cstr("(unsupported)"); + + VALUE remote_address = rb_funcall(io, rb_intern("remote_address"), 0); return rb_funcall(remote_address, rb_intern("inspect_sockaddr"), 0); } @@ -1603,13 +1605,29 @@ peeraddr_ip_str(VALUE self) return rb_rescue2(peer_ip_address, self, fallback_peer_ip_address, (VALUE)0, rb_eSystemCallError, NULL); } +static int +is_real_socket(VALUE io) +{ + // FIXME: DO NOT MERGE + return 0; + return RB_TYPE_P(io, T_FILE); +} + /* * call-seq: * SSLSocket.new(io) => aSSLSocket * SSLSocket.new(io, ctx) => aSSLSocket * - * Creates a new SSL socket from _io_ which must be a real IO object (not an - * IO-like object that responds to read/write). + * Creates a new SSL socket from _io_ which must be an IO object + * or an IO-like object that at least implements the following methods: + * + * - write_nonblock with exception: false + * - read_nonblock with exception: false + * - wait_readable + * - wait_writable + * - flush + * - close + * - closed? * * If _ctx_ is provided the SSL Sockets initial params will be taken from * the context. @@ -1622,7 +1640,7 @@ peeraddr_ip_str(VALUE self) static VALUE ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) { - VALUE io, v_ctx, verify_cb; + VALUE io, v_ctx; SSL *ssl; SSL_CTX *ctx; @@ -1637,9 +1655,18 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) rb_ivar_set(self, id_i_context, v_ctx); ossl_sslctx_setup(v_ctx); - if (rb_respond_to(io, rb_intern("nonblock="))) - rb_funcall(io, rb_intern("nonblock="), 1, Qtrue); - Check_Type(io, T_FILE); + if (is_real_socket(io)) { + rb_io_t *fptr; + GetOpenFile(io, fptr); + rb_io_set_nonblock(fptr); + } + else { + // Not meant to be a comprehensive check + if (!rb_respond_to(io, rb_intern("read_nonblock")) || + !rb_respond_to(io, rb_intern("write_nonblock"))) + rb_raise(rb_eTypeError, "io must be a real IO object or an IO-like " + "object that responds to read_nonblock and write_nonblock"); + } rb_ivar_set(self, id_i_io, io); ssl = SSL_new(ctx); @@ -1649,10 +1676,6 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) SSL_set_ex_data(ssl, ossl_ssl_ex_ptr_idx, (void *)self); SSL_set_info_callback(ssl, ssl_info_cb); - verify_cb = rb_attr_get(v_ctx, id_i_verify_callback); - // We don't need to trigger a write barrier because it's already - // an instance variable of this object. - SSL_set_ex_data(ssl, ossl_ssl_ex_vcb_idx, (void *)verify_cb); rb_call_super(0, NULL); @@ -1675,18 +1698,24 @@ ossl_ssl_setup(VALUE self) { VALUE io; SSL *ssl; - rb_io_t *fptr; GetSSL(self, ssl); if (ssl_started(ssl)) return Qtrue; io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); - rb_io_check_readable(fptr); - rb_io_check_writable(fptr); - if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) - ossl_raise(eSSLError, "SSL_set_fd"); + if (is_real_socket(io)) { + rb_io_t *fptr; + GetOpenFile(io, fptr); + rb_io_check_readable(fptr); + rb_io_check_writable(fptr); + if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) + ossl_raise(eSSLError, "SSL_set_fd"); + } + else { + BIO *bio = ossl_bio_new(io); + SSL_set_bio(ssl, bio, bio); + } return Qtrue; } @@ -1697,6 +1726,38 @@ ossl_ssl_setup(VALUE self) #define ssl_get_error(ssl, ret) SSL_get_error((ssl), (ret)) #endif +static void +check_bio_error(SSL *ssl, VALUE io, int ret) +{ + if (is_real_socket(io)) + return; + + BIO *bio = SSL_get_rbio(ssl); + int state = ossl_bio_state(bio); + if (!state) + return; + + /* + * Operation may succeed while the underlying socket reports an error in + * some cases. For example, when TLS 1.3 server tries to send a + * NewSessionTicket on a closed socket (IOW, when the client disconnects + * right after finishing a handshake). + * + * According to ssl/statem/statem_srvr.c conn_is_closed(), EPIPE and + * ECONNRESET may be ignored. + * + * FIXME BEFORE MERGE: Currently ignoring all SystemCallError. + */ + int error_code = SSL_get_error(ssl, ret); + if ((ret > 0 || error_code == SSL_ERROR_ZERO_RETURN || error_code == SSL_ERROR_SSL) && + rb_obj_is_kind_of(rb_errinfo(), rb_eSystemCallError)) { + rb_set_errinfo(Qnil); + return; + } + ossl_clear_error(); + rb_jump_tag(state); +} + static void write_would_block(int nonblock) { @@ -1735,6 +1796,11 @@ no_exception_p(VALUE opts) static void io_wait_writable(VALUE io) { + if (!is_real_socket(io)) { + if (!RTEST(rb_funcallv(io, rb_intern("wait_writable"), 0, NULL))) + rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!"); + return; + } #ifdef HAVE_RB_IO_MAYBE_WAIT if (!rb_io_maybe_wait_writable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) { rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!"); @@ -1749,6 +1815,11 @@ io_wait_writable(VALUE io) static void io_wait_readable(VALUE io) { + if (!is_real_socket(io)) { + if (!RTEST(rb_funcallv(io, rb_intern("wait_readable"), 0, NULL))) + rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!"); + return; + } #ifdef HAVE_RB_IO_MAYBE_WAIT if (!rb_io_maybe_wait_readable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) { rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!"); @@ -1773,8 +1844,10 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) GetSSL(self, ssl); VALUE io = rb_attr_get(self, id_i_io); + for (;;) { ret = func(ssl); + check_bio_error(ssl, io, ret); cb_state = rb_attr_get(self, ID_callback_state); if (!NIL_P(cb_state)) { @@ -1969,6 +2042,8 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) rb_str_locktmp(str); for (;;) { int nread = SSL_read(ssl, RSTRING_PTR(str), ilen); + check_bio_error(ssl, io, nread); + switch (ssl_get_error(ssl, nread)) { case SSL_ERROR_NONE: rb_str_unlocktmp(str); @@ -2056,7 +2131,6 @@ static VALUE ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) { SSL *ssl; - rb_io_t *fptr; int num, nonblock = opts != Qfalse; VALUE tmp; @@ -2066,7 +2140,6 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) tmp = rb_str_new_frozen(StringValue(str)); VALUE io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); /* SSL_write(3ssl) manpage states num == 0 is undefined */ num = RSTRING_LENINT(tmp); @@ -2075,6 +2148,8 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) for (;;) { int nwritten = SSL_write(ssl, RSTRING_PTR(tmp), num); + check_bio_error(ssl, io, nwritten); + switch (ssl_get_error(ssl, nwritten)) { case SSL_ERROR_NONE: return INT2NUM(nwritten); @@ -2152,7 +2227,15 @@ ossl_ssl_stop(VALUE self) GetSSL(self, ssl); if (!ssl_started(ssl)) return Qnil; + ret = SSL_shutdown(ssl); + + /* XXX: Suppressing errors from the underlying socket */ + VALUE io = rb_attr_get(self, id_i_io); + BIO *bio = SSL_get_rbio(ssl); + if (!is_real_socket(io) && ossl_bio_state(bio)) + rb_set_errinfo(Qnil); + if (ret == 1) /* Have already received close_notify */ return Qnil; if (ret == 0) /* Sent close_notify, but we don't wait for reply */ @@ -2603,9 +2686,6 @@ Init_ossl_ssl(void) id_call = rb_intern_const("call"); ID_callback_state = rb_intern_const("callback_state"); - ossl_ssl_ex_vcb_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_vcb_idx", 0, 0, 0); - if (ossl_ssl_ex_vcb_idx < 0) - ossl_raise(rb_eRuntimeError, "SSL_get_ex_new_index"); ossl_ssl_ex_ptr_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_ptr_idx", 0, 0, 0); if (ossl_ssl_ex_ptr_idx < 0) ossl_raise(rb_eRuntimeError, "SSL_get_ex_new_index"); diff --git a/lib/openssl/buffering.rb b/lib/openssl/buffering.rb index 85f593af0..9f5e8c0f2 100644 --- a/lib/openssl/buffering.rb +++ b/lib/openssl/buffering.rb @@ -64,7 +64,7 @@ def initialize(*) super @eof = false @rbuffer = Buffer.new - @sync = @io.sync + @sync = @io.respond_to?(:sync) ? @io.sync : true end # diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb index 10942191d..1664c00c8 100644 --- a/test/openssl/test_pair.rb +++ b/test/openssl/test_pair.rb @@ -67,6 +67,32 @@ def create_tcp_client(host, port) end end +module OpenSSL::SSLPairIOish + include OpenSSL::SSLPairM + + def create_tcp_server(host, port) + Addrinfo.tcp(host, port).listen + end + + class TCPSocketWrapper + def initialize(io) @io = io end + def read_nonblock(*args, **kwargs) @io.read_nonblock(*args, **kwargs) end + def write_nonblock(*args, **kwargs) @io.write_nonblock(*args, **kwargs) end + def wait_readable() @io.wait_readable end + def wait_writable() @io.wait_writable end + def flush() @io.flush end + def close() @io.close end + def closed?() @io.closed? end + + # Only used within test_pair.rb + def write(*args) @io.write(*args) end + end + + def create_tcp_client(host, port) + TCPSocketWrapper.new(Addrinfo.tcp(host, port).connect) + end +end + module OpenSSL::TestEOF1M def open_file(content) ssl_pair { |s1, s2| @@ -518,6 +544,12 @@ class OpenSSL::TestEOF1LowlevelSocket < OpenSSL::TestCase include OpenSSL::TestEOF1M end +class OpenSSL::TestEOF1IOish < OpenSSL::TestCase + include OpenSSL::TestEOF + include OpenSSL::SSLPairIOish + include OpenSSL::TestEOF1M +end + class OpenSSL::TestEOF2 < OpenSSL::TestCase include OpenSSL::TestEOF include OpenSSL::SSLPair @@ -530,6 +562,12 @@ class OpenSSL::TestEOF2LowlevelSocket < OpenSSL::TestCase include OpenSSL::TestEOF2M end +class OpenSSL::TestEOF2IOish < OpenSSL::TestCase + include OpenSSL::TestEOF + include OpenSSL::SSLPairIOish + include OpenSSL::TestEOF2M +end + class OpenSSL::TestPair < OpenSSL::TestCase include OpenSSL::SSLPair include OpenSSL::TestPairM @@ -540,4 +578,9 @@ class OpenSSL::TestPairLowlevelSocket < OpenSSL::TestCase include OpenSSL::TestPairM end +class OpenSSL::TestPairIOish < OpenSSL::TestCase + include OpenSSL::SSLPairIOish + include OpenSSL::TestPairM +end + end diff --git a/test/openssl/test_ssl.rb b/test/openssl/test_ssl.rb index f011e881e..315addc28 100644 --- a/test/openssl/test_ssl.rb +++ b/test/openssl/test_ssl.rb @@ -4,17 +4,6 @@ if defined?(OpenSSL::SSL) class OpenSSL::TestSSL < OpenSSL::SSLTestCase - def test_bad_socket - bad_socket = Struct.new(:sync).new - assert_raise TypeError do - socket = OpenSSL::SSL::SSLSocket.new bad_socket - # if the socket is not a T_FILE, `connect` will segv because it tries - # to get the underlying file descriptor but the API it calls assumes - # the object type is T_FILE - socket.connect - end - end - def test_ctx_options ctx = OpenSSL::SSL::SSLContext.new @@ -141,6 +130,65 @@ def test_socket_close_write end end + def test_synthetic_io_sanity_check + obj = Object.new + assert_raise_with_message(TypeError, /read_nonblock/) { OpenSSL::SSL::SSLSocket.new(obj) } + + obj = Object.new + obj.define_singleton_method(:read_nonblock) { |*args, **kwargs| } + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| } + assert_nothing_raised { OpenSSL::SSL::SSLSocket.new(obj) } + end + + def test_synthetic_io + start_server do |port| + tcp = TCPSocket.new("127.0.0.1", port) + obj = Object.new + obj.define_singleton_method(:read_nonblock) { |maxlen, exception:| + tcp.read_nonblock(maxlen, exception: exception) } + obj.define_singleton_method(:write_nonblock) { |str, exception:| + tcp.write_nonblock(str, exception: exception) } + obj.define_singleton_method(:wait_readable) { tcp.wait_readable } + obj.define_singleton_method(:wait_writable) { tcp.wait_writable } + obj.define_singleton_method(:flush) { tcp.flush } + obj.define_singleton_method(:closed?) { tcp.closed? } + + ssl = OpenSSL::SSL::SSLSocket.new(obj) + assert_same obj, ssl.to_io + + ssl.connect + ssl.puts "abc"; assert_equal "abc\n", ssl.gets + ensure + ssl&.close + tcp&.close + end + end + + def test_synthetic_io_write_nonblock_exception + start_server(ignore_listener_error: true) do |port| + tcp = TCPSocket.new("127.0.0.1", port) + obj = Object.new + [:read_nonblock, :wait_readable, :wait_writable, :flush, :closed?].each do |name| + obj.define_singleton_method(name) { |*args, **kwargs| + tcp.__send__(name, *args, **kwargs) } + end + + # SSLSocket#connect calls write_nonblock at least twice: ClientHello and Finished + # Let's break the second call + called = 0 + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| + raise "foo" if (called += 1) == 2 + tcp.write_nonblock(*args, **kwargs) + } + + ssl = OpenSSL::SSL::SSLSocket.new(obj) + assert_raise_with_message(RuntimeError, "foo") { ssl.connect } + ensure + ssl&.close + tcp&.close + end + end + def test_add_certificate ctx_proc = -> ctx { # Unset values set by start_server