diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index 2525d0c87..37d5837fd 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -2054,8 +2054,13 @@ ossl_ssl_read_nonblock(int argc, VALUE *argv, VALUE self) } static VALUE -ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) +ossl_ssl_write_internal_safe(VALUE _args) { + VALUE *args = (VALUE*)_args; + VALUE self = args[0]; + VALUE str = args[1]; + VALUE opts = args[2]; + SSL *ssl; rb_io_t *fptr; int num, nonblock = opts != Qfalse; @@ -2116,6 +2121,21 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) } } +static VALUE +ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) +{ + VALUE args[3] = {self, str, opts}; + int state; + str = rb_str_locktmp(StringValue(str)); + VALUE result = rb_protect(ossl_ssl_write_internal_safe, (VALUE)args, &state); + rb_str_unlocktmp(str); + + if (state) { + rb_jump_tag(state); + } + return result; +} + /* * call-seq: * ssl.syswrite(string) => Integer diff --git a/lib/openssl/buffering.rb b/lib/openssl/buffering.rb index 85f593af0..7ecf04c90 100644 --- a/lib/openssl/buffering.rb +++ b/lib/openssl/buffering.rb @@ -23,26 +23,18 @@ module OpenSSL::Buffering include Enumerable # A buffer which will retain binary encoding. - class Buffer < String - BINARY = Encoding::BINARY - - def initialize - super - - force_encoding(BINARY) - end - - def << string - if string.encoding == BINARY - super(string) - else - super(string.b) + if String.method_defined?(:append_as_bytes) + Buffer = String + else + class Buffer < String + def append_as_bytes(string) + if string.encoding == Encoding::BINARY || string.ascii_only? + self << string + else + self << string.b + end end - - return self end - - alias concat << end ## @@ -352,22 +344,32 @@ def eof? def do_write(s) @wbuffer = Buffer.new unless defined? @wbuffer - @wbuffer << s - @wbuffer.force_encoding(Encoding::BINARY) + @wbuffer.append_as_bytes(s) + @sync ||= false - buffer_size = @wbuffer.size + buffer_size = @wbuffer.bytesize if @sync or buffer_size > BLOCK_SIZE nwrote = 0 begin while nwrote < buffer_size do begin - nwrote += syswrite(@wbuffer[nwrote, buffer_size - nwrote]) + chunk = if nwrote > 0 + @wbuffer.byteslice(nwrote..-1) + else + @wbuffer + end + + nwrote += syswrite(chunk) rescue Errno::EAGAIN retry end end ensure - @wbuffer[0, nwrote] = "" + if nwrote < @wbuffer.bytesize + @wbuffer[0, nwrote] = "" + else + @wbuffer.clear + end end end end