diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c
index 7a297fd9a..6de1cd790 100644
--- a/ext/openssl/ossl_ssl.c
+++ b/ext/openssl/ossl_ssl.c
@@ -1551,7 +1551,11 @@ static void
ossl_ssl_mark(void *ptr)
{
SSL *ssl = ptr;
- rb_gc_mark((VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx));
+ VALUE obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
+
+ // Ensure GC compaction won't move objects referenced by OpenSSL objects
+ rb_gc_mark(obj);
+ rb_gc_mark(rb_attr_get(obj, id_i_io));
}
static void
@@ -1601,13 +1605,29 @@ peeraddr_ip_str(VALUE self)
return rb_rescue2(peer_ip_address, self, fallback_peer_ip_address, (VALUE)0, rb_eSystemCallError, NULL);
}
+static int
+is_real_socket(VALUE io)
+{
+ // FIXME: DO NOT MERGE
+ return 0;
+ return RB_TYPE_P(io, T_FILE);
+}
+
/*
* call-seq:
* SSLSocket.new(io) => aSSLSocket
* SSLSocket.new(io, ctx) => aSSLSocket
*
- * Creates a new SSL socket from _io_ which must be a real IO object (not an
- * IO-like object that responds to read/write).
+ * Creates a new SSL socket from _io_ which must be an IO object
+ * or an IO-like object that at least implements the following methods:
+ *
+ * - write_nonblock with exception: false
+ * - read_nonblock with exception: false
+ * - wait_readable
+ * - wait_writable
+ * - flush
+ * - close
+ * - closed?
*
* If _ctx_ is provided the SSL Sockets initial params will be taken from
* the context.
@@ -1635,9 +1655,18 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self)
rb_ivar_set(self, id_i_context, v_ctx);
ossl_sslctx_setup(v_ctx);
- if (rb_respond_to(io, rb_intern("nonblock=")))
- rb_funcall(io, rb_intern("nonblock="), 1, Qtrue);
- Check_Type(io, T_FILE);
+ if (is_real_socket(io)) {
+ rb_io_t *fptr;
+ GetOpenFile(io, fptr);
+ rb_io_set_nonblock(fptr);
+ }
+ else {
+ // Not meant to be a comprehensive check
+ if (!rb_respond_to(io, rb_intern("read_nonblock")) ||
+ !rb_respond_to(io, rb_intern("write_nonblock")))
+ rb_raise(rb_eTypeError, "io must be a real IO object or an IO-like "
+ "object that responds to read_nonblock and write_nonblock");
+ }
rb_ivar_set(self, id_i_io, io);
ssl = SSL_new(ctx);
@@ -1669,18 +1698,24 @@ ossl_ssl_setup(VALUE self)
{
VALUE io;
SSL *ssl;
- rb_io_t *fptr;
GetSSL(self, ssl);
if (ssl_started(ssl))
return Qtrue;
io = rb_attr_get(self, id_i_io);
- GetOpenFile(io, fptr);
- rb_io_check_readable(fptr);
- rb_io_check_writable(fptr);
- if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io))))
- ossl_raise(eSSLError, "SSL_set_fd");
+ if (is_real_socket(io)) {
+ rb_io_t *fptr;
+ GetOpenFile(io, fptr);
+ rb_io_check_readable(fptr);
+ rb_io_check_writable(fptr);
+ if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io))))
+ ossl_raise(eSSLError, "SSL_set_fd");
+ }
+ else {
+ BIO *bio = ossl_bio_new(io);
+ SSL_set_bio(ssl, bio, bio);
+ }
return Qtrue;
}
@@ -1691,6 +1726,38 @@ ossl_ssl_setup(VALUE self)
#define ssl_get_error(ssl, ret) SSL_get_error((ssl), (ret))
#endif
+static void
+check_bio_error(SSL *ssl, VALUE io, int ret)
+{
+ if (is_real_socket(io))
+ return;
+
+ BIO *bio = SSL_get_rbio(ssl);
+ int state = ossl_bio_state(bio);
+ if (!state)
+ return;
+
+ /*
+ * Operation may succeed while the underlying socket reports an error in
+ * some cases. For example, when TLS 1.3 server tries to send a
+ * NewSessionTicket on a closed socket (IOW, when the client disconnects
+ * right after finishing a handshake).
+ *
+ * According to ssl/statem/statem_srvr.c conn_is_closed(), EPIPE and
+ * ECONNRESET may be ignored.
+ *
+ * FIXME BEFORE MERGE: Currently ignoring all SystemCallError.
+ */
+ int error_code = SSL_get_error(ssl, ret);
+ if ((ret > 0 || error_code == SSL_ERROR_ZERO_RETURN || error_code == SSL_ERROR_SSL) &&
+ rb_obj_is_kind_of(rb_errinfo(), rb_eSystemCallError)) {
+ rb_set_errinfo(Qnil);
+ return;
+ }
+ ossl_clear_error();
+ rb_jump_tag(state);
+}
+
static void
write_would_block(int nonblock)
{
@@ -1729,6 +1796,11 @@ no_exception_p(VALUE opts)
static void
io_wait_writable(VALUE io)
{
+ if (!is_real_socket(io)) {
+ if (!RTEST(rb_funcallv(io, rb_intern("wait_writable"), 0, NULL)))
+ rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!");
+ return;
+ }
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_maybe_wait_writable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!");
@@ -1743,6 +1815,11 @@ io_wait_writable(VALUE io)
static void
io_wait_readable(VALUE io)
{
+ if (!is_real_socket(io)) {
+ if (!RTEST(rb_funcallv(io, rb_intern("wait_readable"), 0, NULL)))
+ rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!");
+ return;
+ }
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_maybe_wait_readable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!");
@@ -1767,8 +1844,10 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
GetSSL(self, ssl);
VALUE io = rb_attr_get(self, id_i_io);
+
for (;;) {
ret = func(ssl);
+ check_bio_error(ssl, io, ret);
cb_state = rb_attr_get(self, ID_callback_state);
if (!NIL_P(cb_state)) {
@@ -1963,6 +2042,8 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
rb_str_locktmp(str);
for (;;) {
int nread = SSL_read(ssl, RSTRING_PTR(str), ilen);
+ check_bio_error(ssl, io, nread);
+
switch (ssl_get_error(ssl, nread)) {
case SSL_ERROR_NONE:
rb_str_unlocktmp(str);
@@ -2067,6 +2148,8 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)
for (;;) {
int nwritten = SSL_write(ssl, RSTRING_PTR(tmp), num);
+ check_bio_error(ssl, io, nwritten);
+
switch (ssl_get_error(ssl, nwritten)) {
case SSL_ERROR_NONE:
return INT2NUM(nwritten);
@@ -2144,7 +2227,15 @@ ossl_ssl_stop(VALUE self)
GetSSL(self, ssl);
if (!ssl_started(ssl))
return Qnil;
+
ret = SSL_shutdown(ssl);
+
+ /* XXX: Suppressing errors from the underlying socket */
+ VALUE io = rb_attr_get(self, id_i_io);
+ BIO *bio = SSL_get_rbio(ssl);
+ if (!is_real_socket(io) && ossl_bio_state(bio))
+ rb_set_errinfo(Qnil);
+
if (ret == 1) /* Have already received close_notify */
return Qnil;
if (ret == 0) /* Sent close_notify, but we don't wait for reply */
diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb
index 10942191d..1664c00c8 100644
--- a/test/openssl/test_pair.rb
+++ b/test/openssl/test_pair.rb
@@ -67,6 +67,32 @@ def create_tcp_client(host, port)
end
end
+module OpenSSL::SSLPairIOish
+ include OpenSSL::SSLPairM
+
+ def create_tcp_server(host, port)
+ Addrinfo.tcp(host, port).listen
+ end
+
+ class TCPSocketWrapper
+ def initialize(io) @io = io end
+ def read_nonblock(*args, **kwargs) @io.read_nonblock(*args, **kwargs) end
+ def write_nonblock(*args, **kwargs) @io.write_nonblock(*args, **kwargs) end
+ def wait_readable() @io.wait_readable end
+ def wait_writable() @io.wait_writable end
+ def flush() @io.flush end
+ def close() @io.close end
+ def closed?() @io.closed? end
+
+ # Only used within test_pair.rb
+ def write(*args) @io.write(*args) end
+ end
+
+ def create_tcp_client(host, port)
+ TCPSocketWrapper.new(Addrinfo.tcp(host, port).connect)
+ end
+end
+
module OpenSSL::TestEOF1M
def open_file(content)
ssl_pair { |s1, s2|
@@ -518,6 +544,12 @@ class OpenSSL::TestEOF1LowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestEOF1M
end
+class OpenSSL::TestEOF1IOish < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPairIOish
+ include OpenSSL::TestEOF1M
+end
+
class OpenSSL::TestEOF2 < OpenSSL::TestCase
include OpenSSL::TestEOF
include OpenSSL::SSLPair
@@ -530,6 +562,12 @@ class OpenSSL::TestEOF2LowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestEOF2M
end
+class OpenSSL::TestEOF2IOish < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPairIOish
+ include OpenSSL::TestEOF2M
+end
+
class OpenSSL::TestPair < OpenSSL::TestCase
include OpenSSL::SSLPair
include OpenSSL::TestPairM
@@ -540,4 +578,9 @@ class OpenSSL::TestPairLowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestPairM
end
+class OpenSSL::TestPairIOish < OpenSSL::TestCase
+ include OpenSSL::SSLPairIOish
+ include OpenSSL::TestPairM
+end
+
end
diff --git a/test/openssl/test_ssl.rb b/test/openssl/test_ssl.rb
index f011e881e..315addc28 100644
--- a/test/openssl/test_ssl.rb
+++ b/test/openssl/test_ssl.rb
@@ -4,17 +4,6 @@
if defined?(OpenSSL::SSL)
class OpenSSL::TestSSL < OpenSSL::SSLTestCase
- def test_bad_socket
- bad_socket = Struct.new(:sync).new
- assert_raise TypeError do
- socket = OpenSSL::SSL::SSLSocket.new bad_socket
- # if the socket is not a T_FILE, `connect` will segv because it tries
- # to get the underlying file descriptor but the API it calls assumes
- # the object type is T_FILE
- socket.connect
- end
- end
-
def test_ctx_options
ctx = OpenSSL::SSL::SSLContext.new
@@ -141,6 +130,65 @@ def test_socket_close_write
end
end
+ def test_synthetic_io_sanity_check
+ obj = Object.new
+ assert_raise_with_message(TypeError, /read_nonblock/) { OpenSSL::SSL::SSLSocket.new(obj) }
+
+ obj = Object.new
+ obj.define_singleton_method(:read_nonblock) { |*args, **kwargs| }
+ obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| }
+ assert_nothing_raised { OpenSSL::SSL::SSLSocket.new(obj) }
+ end
+
+ def test_synthetic_io
+ start_server do |port|
+ tcp = TCPSocket.new("127.0.0.1", port)
+ obj = Object.new
+ obj.define_singleton_method(:read_nonblock) { |maxlen, exception:|
+ tcp.read_nonblock(maxlen, exception: exception) }
+ obj.define_singleton_method(:write_nonblock) { |str, exception:|
+ tcp.write_nonblock(str, exception: exception) }
+ obj.define_singleton_method(:wait_readable) { tcp.wait_readable }
+ obj.define_singleton_method(:wait_writable) { tcp.wait_writable }
+ obj.define_singleton_method(:flush) { tcp.flush }
+ obj.define_singleton_method(:closed?) { tcp.closed? }
+
+ ssl = OpenSSL::SSL::SSLSocket.new(obj)
+ assert_same obj, ssl.to_io
+
+ ssl.connect
+ ssl.puts "abc"; assert_equal "abc\n", ssl.gets
+ ensure
+ ssl&.close
+ tcp&.close
+ end
+ end
+
+ def test_synthetic_io_write_nonblock_exception
+ start_server(ignore_listener_error: true) do |port|
+ tcp = TCPSocket.new("127.0.0.1", port)
+ obj = Object.new
+ [:read_nonblock, :wait_readable, :wait_writable, :flush, :closed?].each do |name|
+ obj.define_singleton_method(name) { |*args, **kwargs|
+ tcp.__send__(name, *args, **kwargs) }
+ end
+
+ # SSLSocket#connect calls write_nonblock at least twice: ClientHello and Finished
+ # Let's break the second call
+ called = 0
+ obj.define_singleton_method(:write_nonblock) { |*args, **kwargs|
+ raise "foo" if (called += 1) == 2
+ tcp.write_nonblock(*args, **kwargs)
+ }
+
+ ssl = OpenSSL::SSL::SSLSocket.new(obj)
+ assert_raise_with_message(RuntimeError, "foo") { ssl.connect }
+ ensure
+ ssl&.close
+ tcp&.close
+ end
+ end
+
def test_add_certificate
ctx_proc = -> ctx {
# Unset values set by start_server