diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index 7a297fd9a..6de1cd790 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -1551,7 +1551,11 @@ static void ossl_ssl_mark(void *ptr) { SSL *ssl = ptr; - rb_gc_mark((VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx)); + VALUE obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); + + // Ensure GC compaction won't move objects referenced by OpenSSL objects + rb_gc_mark(obj); + rb_gc_mark(rb_attr_get(obj, id_i_io)); } static void @@ -1601,13 +1605,29 @@ peeraddr_ip_str(VALUE self) return rb_rescue2(peer_ip_address, self, fallback_peer_ip_address, (VALUE)0, rb_eSystemCallError, NULL); } +static int +is_real_socket(VALUE io) +{ + // FIXME: DO NOT MERGE + return 0; + return RB_TYPE_P(io, T_FILE); +} + /* * call-seq: * SSLSocket.new(io) => aSSLSocket * SSLSocket.new(io, ctx) => aSSLSocket * - * Creates a new SSL socket from _io_ which must be a real IO object (not an - * IO-like object that responds to read/write). + * Creates a new SSL socket from _io_ which must be an IO object + * or an IO-like object that at least implements the following methods: + * + * - write_nonblock with exception: false + * - read_nonblock with exception: false + * - wait_readable + * - wait_writable + * - flush + * - close + * - closed? * * If _ctx_ is provided the SSL Sockets initial params will be taken from * the context. @@ -1635,9 +1655,18 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) rb_ivar_set(self, id_i_context, v_ctx); ossl_sslctx_setup(v_ctx); - if (rb_respond_to(io, rb_intern("nonblock="))) - rb_funcall(io, rb_intern("nonblock="), 1, Qtrue); - Check_Type(io, T_FILE); + if (is_real_socket(io)) { + rb_io_t *fptr; + GetOpenFile(io, fptr); + rb_io_set_nonblock(fptr); + } + else { + // Not meant to be a comprehensive check + if (!rb_respond_to(io, rb_intern("read_nonblock")) || + !rb_respond_to(io, rb_intern("write_nonblock"))) + rb_raise(rb_eTypeError, "io must be a real IO object or an IO-like " + "object that responds to read_nonblock and write_nonblock"); + } rb_ivar_set(self, id_i_io, io); ssl = SSL_new(ctx); @@ -1669,18 +1698,24 @@ ossl_ssl_setup(VALUE self) { VALUE io; SSL *ssl; - rb_io_t *fptr; GetSSL(self, ssl); if (ssl_started(ssl)) return Qtrue; io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); - rb_io_check_readable(fptr); - rb_io_check_writable(fptr); - if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) - ossl_raise(eSSLError, "SSL_set_fd"); + if (is_real_socket(io)) { + rb_io_t *fptr; + GetOpenFile(io, fptr); + rb_io_check_readable(fptr); + rb_io_check_writable(fptr); + if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) + ossl_raise(eSSLError, "SSL_set_fd"); + } + else { + BIO *bio = ossl_bio_new(io); + SSL_set_bio(ssl, bio, bio); + } return Qtrue; } @@ -1691,6 +1726,38 @@ ossl_ssl_setup(VALUE self) #define ssl_get_error(ssl, ret) SSL_get_error((ssl), (ret)) #endif +static void +check_bio_error(SSL *ssl, VALUE io, int ret) +{ + if (is_real_socket(io)) + return; + + BIO *bio = SSL_get_rbio(ssl); + int state = ossl_bio_state(bio); + if (!state) + return; + + /* + * Operation may succeed while the underlying socket reports an error in + * some cases. For example, when TLS 1.3 server tries to send a + * NewSessionTicket on a closed socket (IOW, when the client disconnects + * right after finishing a handshake). + * + * According to ssl/statem/statem_srvr.c conn_is_closed(), EPIPE and + * ECONNRESET may be ignored. + * + * FIXME BEFORE MERGE: Currently ignoring all SystemCallError. + */ + int error_code = SSL_get_error(ssl, ret); + if ((ret > 0 || error_code == SSL_ERROR_ZERO_RETURN || error_code == SSL_ERROR_SSL) && + rb_obj_is_kind_of(rb_errinfo(), rb_eSystemCallError)) { + rb_set_errinfo(Qnil); + return; + } + ossl_clear_error(); + rb_jump_tag(state); +} + static void write_would_block(int nonblock) { @@ -1729,6 +1796,11 @@ no_exception_p(VALUE opts) static void io_wait_writable(VALUE io) { + if (!is_real_socket(io)) { + if (!RTEST(rb_funcallv(io, rb_intern("wait_writable"), 0, NULL))) + rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!"); + return; + } #ifdef HAVE_RB_IO_MAYBE_WAIT if (!rb_io_maybe_wait_writable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) { rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!"); @@ -1743,6 +1815,11 @@ io_wait_writable(VALUE io) static void io_wait_readable(VALUE io) { + if (!is_real_socket(io)) { + if (!RTEST(rb_funcallv(io, rb_intern("wait_readable"), 0, NULL))) + rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!"); + return; + } #ifdef HAVE_RB_IO_MAYBE_WAIT if (!rb_io_maybe_wait_readable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) { rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!"); @@ -1767,8 +1844,10 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) GetSSL(self, ssl); VALUE io = rb_attr_get(self, id_i_io); + for (;;) { ret = func(ssl); + check_bio_error(ssl, io, ret); cb_state = rb_attr_get(self, ID_callback_state); if (!NIL_P(cb_state)) { @@ -1963,6 +2042,8 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) rb_str_locktmp(str); for (;;) { int nread = SSL_read(ssl, RSTRING_PTR(str), ilen); + check_bio_error(ssl, io, nread); + switch (ssl_get_error(ssl, nread)) { case SSL_ERROR_NONE: rb_str_unlocktmp(str); @@ -2067,6 +2148,8 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) for (;;) { int nwritten = SSL_write(ssl, RSTRING_PTR(tmp), num); + check_bio_error(ssl, io, nwritten); + switch (ssl_get_error(ssl, nwritten)) { case SSL_ERROR_NONE: return INT2NUM(nwritten); @@ -2144,7 +2227,15 @@ ossl_ssl_stop(VALUE self) GetSSL(self, ssl); if (!ssl_started(ssl)) return Qnil; + ret = SSL_shutdown(ssl); + + /* XXX: Suppressing errors from the underlying socket */ + VALUE io = rb_attr_get(self, id_i_io); + BIO *bio = SSL_get_rbio(ssl); + if (!is_real_socket(io) && ossl_bio_state(bio)) + rb_set_errinfo(Qnil); + if (ret == 1) /* Have already received close_notify */ return Qnil; if (ret == 0) /* Sent close_notify, but we don't wait for reply */ diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb index 10942191d..1664c00c8 100644 --- a/test/openssl/test_pair.rb +++ b/test/openssl/test_pair.rb @@ -67,6 +67,32 @@ def create_tcp_client(host, port) end end +module OpenSSL::SSLPairIOish + include OpenSSL::SSLPairM + + def create_tcp_server(host, port) + Addrinfo.tcp(host, port).listen + end + + class TCPSocketWrapper + def initialize(io) @io = io end + def read_nonblock(*args, **kwargs) @io.read_nonblock(*args, **kwargs) end + def write_nonblock(*args, **kwargs) @io.write_nonblock(*args, **kwargs) end + def wait_readable() @io.wait_readable end + def wait_writable() @io.wait_writable end + def flush() @io.flush end + def close() @io.close end + def closed?() @io.closed? end + + # Only used within test_pair.rb + def write(*args) @io.write(*args) end + end + + def create_tcp_client(host, port) + TCPSocketWrapper.new(Addrinfo.tcp(host, port).connect) + end +end + module OpenSSL::TestEOF1M def open_file(content) ssl_pair { |s1, s2| @@ -518,6 +544,12 @@ class OpenSSL::TestEOF1LowlevelSocket < OpenSSL::TestCase include OpenSSL::TestEOF1M end +class OpenSSL::TestEOF1IOish < OpenSSL::TestCase + include OpenSSL::TestEOF + include OpenSSL::SSLPairIOish + include OpenSSL::TestEOF1M +end + class OpenSSL::TestEOF2 < OpenSSL::TestCase include OpenSSL::TestEOF include OpenSSL::SSLPair @@ -530,6 +562,12 @@ class OpenSSL::TestEOF2LowlevelSocket < OpenSSL::TestCase include OpenSSL::TestEOF2M end +class OpenSSL::TestEOF2IOish < OpenSSL::TestCase + include OpenSSL::TestEOF + include OpenSSL::SSLPairIOish + include OpenSSL::TestEOF2M +end + class OpenSSL::TestPair < OpenSSL::TestCase include OpenSSL::SSLPair include OpenSSL::TestPairM @@ -540,4 +578,9 @@ class OpenSSL::TestPairLowlevelSocket < OpenSSL::TestCase include OpenSSL::TestPairM end +class OpenSSL::TestPairIOish < OpenSSL::TestCase + include OpenSSL::SSLPairIOish + include OpenSSL::TestPairM +end + end diff --git a/test/openssl/test_ssl.rb b/test/openssl/test_ssl.rb index f011e881e..315addc28 100644 --- a/test/openssl/test_ssl.rb +++ b/test/openssl/test_ssl.rb @@ -4,17 +4,6 @@ if defined?(OpenSSL::SSL) class OpenSSL::TestSSL < OpenSSL::SSLTestCase - def test_bad_socket - bad_socket = Struct.new(:sync).new - assert_raise TypeError do - socket = OpenSSL::SSL::SSLSocket.new bad_socket - # if the socket is not a T_FILE, `connect` will segv because it tries - # to get the underlying file descriptor but the API it calls assumes - # the object type is T_FILE - socket.connect - end - end - def test_ctx_options ctx = OpenSSL::SSL::SSLContext.new @@ -141,6 +130,65 @@ def test_socket_close_write end end + def test_synthetic_io_sanity_check + obj = Object.new + assert_raise_with_message(TypeError, /read_nonblock/) { OpenSSL::SSL::SSLSocket.new(obj) } + + obj = Object.new + obj.define_singleton_method(:read_nonblock) { |*args, **kwargs| } + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| } + assert_nothing_raised { OpenSSL::SSL::SSLSocket.new(obj) } + end + + def test_synthetic_io + start_server do |port| + tcp = TCPSocket.new("127.0.0.1", port) + obj = Object.new + obj.define_singleton_method(:read_nonblock) { |maxlen, exception:| + tcp.read_nonblock(maxlen, exception: exception) } + obj.define_singleton_method(:write_nonblock) { |str, exception:| + tcp.write_nonblock(str, exception: exception) } + obj.define_singleton_method(:wait_readable) { tcp.wait_readable } + obj.define_singleton_method(:wait_writable) { tcp.wait_writable } + obj.define_singleton_method(:flush) { tcp.flush } + obj.define_singleton_method(:closed?) { tcp.closed? } + + ssl = OpenSSL::SSL::SSLSocket.new(obj) + assert_same obj, ssl.to_io + + ssl.connect + ssl.puts "abc"; assert_equal "abc\n", ssl.gets + ensure + ssl&.close + tcp&.close + end + end + + def test_synthetic_io_write_nonblock_exception + start_server(ignore_listener_error: true) do |port| + tcp = TCPSocket.new("127.0.0.1", port) + obj = Object.new + [:read_nonblock, :wait_readable, :wait_writable, :flush, :closed?].each do |name| + obj.define_singleton_method(name) { |*args, **kwargs| + tcp.__send__(name, *args, **kwargs) } + end + + # SSLSocket#connect calls write_nonblock at least twice: ClientHello and Finished + # Let's break the second call + called = 0 + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| + raise "foo" if (called += 1) == 2 + tcp.write_nonblock(*args, **kwargs) + } + + ssl = OpenSSL::SSL::SSLSocket.new(obj) + assert_raise_with_message(RuntimeError, "foo") { ssl.connect } + ensure + ssl&.close + tcp&.close + end + end + def test_add_certificate ctx_proc = -> ctx { # Unset values set by start_server