diff --git a/ext/openssl/ossl_bio.c b/ext/openssl/ossl_bio.c index 42833d901..b4e029592 100644 --- a/ext/openssl/ossl_bio.c +++ b/ext/openssl/ossl_bio.c @@ -40,3 +40,108 @@ ossl_membio2str(BIO *bio) return ret; } + +int +ossl_membio_sock_read(BIO* bio, VALUE io) { + VALUE nonblock_kwargs = rb_hash_new(); + rb_hash_aset(nonblock_kwargs, ID2SYM(rb_intern("exception")), Qfalse); + + printf("reading...\n"); + + VALUE fargs[] = { INT2NUM(4096), nonblock_kwargs }; + VALUE ret = rb_funcallv_public_kw(io, rb_intern("read_nonblock"), 2, fargs, RB_PASS_KEYWORDS); + printf("just read...\n"); + int len; + char *bstr; + + if (RB_TYPE_P(ret, T_STRING)) { + len = RSTRING_LENINT(ret); + bstr = RSTRING_PTR(ret); + printf("read the nonblock: %d...\n", len); + } + else if (ret == ID2SYM(rb_intern("wait_readable"))) { + // BIO_set_retry_read(bio); + return SSL_ERROR_WANT_READ; + } + else if (ret == ID2SYM(rb_intern("wait_writable"))) { + // BIO_set_retry_write(bio); + return SSL_ERROR_WANT_WRITE; + } + else if (NIL_P(ret)) { + printf("fuck the nil\n"); + return SSL_ERROR_ZERO_RETURN; + } + else { + printf("elsing\n"); + rb_raise(rb_eTypeError, "write_nonblock must return an Integer, " + ":wait_readable, or :wait_writable"); + } + + while (len > 0) { + int n = BIO_write(bio, bstr, len); + BIO_clear_retry_flags(bio); + + if (n<=0) + return SSL_ERROR_SYSCALL; // unrecoverable + + bstr += n; + len -= n; + + // // finish handshake if required + // if (!SSL_is_init_finished(client.ssl)) { + // if (do_ssl_handshake() == SSLSTATUS_FAIL) + // return SSL_ERROR_SYSCALL; + // if (!SSL_is_init_finished(client.ssl)) + // // assume there are bytes missing + // return SSL_ERROR_WANT_READ; + // } + } + return SSL_ERROR_NONE; +} + +int +ossl_membio_sock_write(BIO* bio, VALUE io) { + char buf[4096]; + char *p = buf; + + int n = BIO_read(bio, p, 4096); + BIO_clear_retry_flags(bio); + if (n <= 0) { + if (!BIO_should_retry(bio)) + // TODO: raise exception + return -1; + } + + printf("writing to bio 2: %d\n", n); + + VALUE nonblock_kwargs = rb_hash_new(); + rb_hash_aset(nonblock_kwargs, ID2SYM(rb_intern("exception")), Qfalse); + + VALUE fargs[] = { rb_str_new_static(buf, n), nonblock_kwargs }; + + // rb_io_write(rb_stdout ,rb_sprintf("%s\n", RSTRING_PTR(*biobuf))); + // rb_p(*biobuf); + VALUE ret = rb_funcallv_public_kw(io, rb_intern("write_nonblock"), 2, fargs, RB_PASS_KEYWORDS); + + if (RB_INTEGER_TYPE_P(ret)) { + // TODO: resize buffer + return SSL_ERROR_NONE; + } + else if (ret == ID2SYM(rb_intern("wait_readable"))) { + printf("wred\n"); + // BIO_set_retry_read(bio); + return SSL_ERROR_WANT_READ; + } + else if (ret == ID2SYM(rb_intern("wait_writable"))) { + printf("wwrit\n"); + // BIO_set_retry_write(bio); + return SSL_ERROR_WANT_WRITE; + } else if (NIL_P(ret)) { + printf("closed\n"); + return SSL_ERROR_ZERO_RETURN; + } + else { + rb_raise(rb_eTypeError, "write_nonblock must return an Integer, " + ":wait_readable, or :wait_writable"); + } +} \ No newline at end of file diff --git a/ext/openssl/ossl_bio.h b/ext/openssl/ossl_bio.h index da68c5e5a..a07b8572e 100644 --- a/ext/openssl/ossl_bio.h +++ b/ext/openssl/ossl_bio.h @@ -12,5 +12,7 @@ BIO *ossl_obj2bio(volatile VALUE *); VALUE ossl_membio2str(BIO*); +int ossl_membio_sock_read(BIO *, VALUE); +int ossl_membio_sock_write(BIO * , VALUE); #endif diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index c63f5868f..0f3f0ce67 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -57,8 +57,15 @@ static ID id_i_io, id_i_context, id_i_hostname; static int ossl_ssl_ex_vcb_idx; static int ossl_ssl_ex_ptr_idx; +static int ossl_ssl_ex_rbio_idx; +static int ossl_ssl_ex_wbio_idx; static int ossl_sslctx_ex_ptr_idx; + +int IsSock(VALUE io) { + return RB_TYPE_P(io, T_FILE); +} + static void ossl_sslctx_mark(void *ptr) { @@ -1639,7 +1646,7 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) if (rb_respond_to(io, rb_intern("nonblock="))) rb_funcall(io, rb_intern("nonblock="), 1, Qtrue); - Check_Type(io, T_FILE); + // Check_Type(io, T_FILE); rb_ivar_set(self, id_i_io, io); ssl = SSL_new(ctx); @@ -1682,11 +1689,22 @@ ossl_ssl_setup(VALUE self) return Qtrue; io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); - rb_io_check_readable(fptr); - rb_io_check_writable(fptr); - if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) - ossl_raise(eSSLError, "SSL_set_fd"); + + if (IsSock(io)) { + GetOpenFile(io, fptr); + rb_io_check_readable(fptr); + rb_io_check_writable(fptr); + if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) + ossl_raise(eSSLError, "SSL_set_fd"); + } else { + // something which quacks like an IO + // TODO: how to best ensure this from the C API?? + BIO *rbio = BIO_new(BIO_s_mem()); + BIO *wbio = BIO_new(BIO_s_mem()); + SSL_set_bio(ssl, rbio, wbio); + SSL_set_ex_data(ssl, ossl_ssl_ex_rbio_idx, (void *)rbio); + SSL_set_ex_data(ssl, ossl_ssl_ex_wbio_idx, (void *)wbio); + } return Qtrue; } @@ -1789,6 +1807,15 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) GetSSL(self, ssl); VALUE io = rb_attr_get(self, id_i_io); + + int is_sock = IsSock(io); + BIO *rbio, *wbio; + + if (!is_sock) { + rbio = (BIO *)SSL_get_ex_data(ssl, ossl_ssl_ex_rbio_idx); + wbio = (BIO *)SSL_get_ex_data(ssl, ossl_ssl_ex_wbio_idx); + } + for (;;) { ret = func(ssl); @@ -1798,11 +1825,35 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) ossl_clear_error(); rb_jump_tag(NUM2INT(cb_state)); } + printf("connect ret: %d\n", ret); if (ret > 0) break; - switch ((ret2 = ssl_get_error(ssl, ret))) { + ret2 = ssl_get_error(ssl, ret); + printf("WANT_READ: %d, WANT_WRITE. %d\n", SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE); + printf("connect ret2: %d\n", ret2); + + if (!is_sock) { + if(ret2 == SSL_ERROR_WANT_WRITE) { + ret2 = ossl_membio_sock_read(rbio, io); + if (ret2 == SSL_ERROR_NONE) { + ret2 = SSL_ERROR_WANT_WRITE; + printf("out fuckerz\n"); + break; + } + } else if (ret2 == SSL_ERROR_WANT_READ) { + ret2 = ossl_membio_sock_write(wbio, io); + if (ret2 == SSL_ERROR_NONE) { + printf("out fuckerz\n"); + break; + // continue; + } + } + } + printf("connect after is_sock ret2: %d\n", ret2); + + switch ((ret2)) { case SSL_ERROR_WANT_WRITE: if (no_exception_p(opts)) { return sym_wait_writable; } write_would_block(nonblock); @@ -1954,6 +2005,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) int ilen; VALUE len, str; VALUE opts = Qnil; + BIO *rbio, *wbio; if (nonblock) { rb_scan_args(argc, argv, "11:", &len, &str, &opts); @@ -1979,11 +2031,28 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) return str; VALUE io = rb_attr_get(self, id_i_io); + int is_sock = IsSock(io); + + if (!is_sock) { + rbio = (BIO *)SSL_get_ex_data(ssl, ossl_ssl_ex_rbio_idx); + wbio = (BIO *)SSL_get_ex_data(ssl, ossl_ssl_ex_wbio_idx); + } rb_str_locktmp(str); for (;;) { - int nread = ossl_ssl_read_impl(ssl, str, ilen); - switch (ssl_get_error(ssl, nread)) { + int ret, nread; + + if (!is_sock) { + ret = ossl_membio_sock_read(rbio, io); + if (ret == SSL_ERROR_NONE) { + ret = ossl_membio_sock_write(wbio, io); + } + } else { + nread = ossl_ssl_read_impl(ssl, str, ilen); + ret = ssl_get_error(ssl, nread); + } + + switch (ret) { case SSL_ERROR_NONE: rb_str_unlocktmp(str); rb_str_set_len(str, nread); @@ -2080,7 +2149,16 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) tmp = rb_str_new_frozen(StringValue(str)); VALUE io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); + int is_sock = IsSock(io); + // BIO *rbio; + BIO *wbio; + + if (!is_sock) { + // rbio = (BIO *)SSL_get_ex_data(ssl, ossl_ssl_ex_rbio_idx); + wbio = (BIO *)SSL_get_ex_data(ssl, ossl_ssl_ex_wbio_idx); + } else { + GetOpenFile(io, fptr); + } /* SSL_write(3ssl) manpage states num == 0 is undefined */ num = RSTRING_LENINT(tmp); @@ -2088,8 +2166,18 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) return INT2FIX(0); for (;;) { + + int nwritten = ossl_ssl_write_impl(ssl, tmp, num); - switch (ssl_get_error(ssl, nwritten)) { + int ret = ssl_get_error(ssl, nwritten); + + if (!is_sock) { + if (ret == SSL_ERROR_NONE) { + ret = ossl_membio_sock_write(wbio, io); + } + } + + switch (ret) { case SSL_ERROR_NONE: return INT2NUM(nwritten); case SSL_ERROR_WANT_WRITE: @@ -2623,6 +2711,12 @@ Init_ossl_ssl(void) ossl_ssl_ex_ptr_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_ptr_idx", 0, 0, 0); if (ossl_ssl_ex_ptr_idx < 0) ossl_raise(rb_eRuntimeError, "SSL_get_ex_new_index"); + ossl_ssl_ex_rbio_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_rbio_idx", 0, 0, 0); + if (ossl_ssl_ex_rbio_idx < 0) + ossl_raise(rb_eRuntimeError, "SSL_get_ex_new_index"); + ossl_ssl_ex_wbio_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_wbio_idx", 0, 0, 0); + if (ossl_ssl_ex_wbio_idx < 0) + ossl_raise(rb_eRuntimeError, "SSL_get_ex_new_index"); ossl_sslctx_ex_ptr_idx = SSL_CTX_get_ex_new_index(0, (void *)"ossl_sslctx_ex_ptr_idx", 0, 0, 0); if (ossl_sslctx_ex_ptr_idx < 0) ossl_raise(rb_eRuntimeError, "SSL_CTX_get_ex_new_index"); @@ -3152,6 +3246,11 @@ Init_ossl_ssl(void) id_npn_protocols_encoded = rb_intern_const("npn_protocols_encoded"); id_each = rb_intern_const("each"); + nonblock_kwargs = rb_hash_new(); + rb_hash_aset(nonblock_kwargs, sym_exception, Qfalse); + rb_obj_freeze(nonblock_kwargs); + rb_global_variable(&nonblock_kwargs); + #define DefIVarID(name) do \ id_i_##name = rb_intern_const("@"#name); while (0) diff --git a/ext/openssl/ossl_ssl.h b/ext/openssl/ossl_ssl.h index 535c56097..29c9de091 100644 --- a/ext/openssl/ossl_ssl.h +++ b/ext/openssl/ossl_ssl.h @@ -29,6 +29,7 @@ extern const rb_data_type_t ossl_ssl_session_type; extern VALUE mSSL; extern VALUE cSSLSocket; extern VALUE cSSLSession; +static VALUE nonblock_kwargs; void Init_ossl_ssl(void); void Init_ossl_ssl_session(void);