Skip to content

Commit

Permalink
Reduce OpenSSL::Buffering#do_write overhead
Browse files Browse the repository at this point in the history
[Bug #20972]

The `rb_str_new_freeze` was added in #452
to better handle concurrent use of a Socket, but SSL sockets can't be used
concurrently AFAIK, so we might as well just error cleanly.

By using `rb_str_locktmp` we can ensure attempts at concurrent write
will raise an error, be we avoid causing a copy of the bytes.

We also use the newer `String#append_as_bytes` method when available
to save on some more copies.

Co-Authored-By: [email protected]
  • Loading branch information
byroot committed Dec 27, 2024
1 parent f4e7c4b commit 0980748
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 28 deletions.
29 changes: 24 additions & 5 deletions ext/openssl/ossl_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -2054,28 +2054,32 @@ ossl_ssl_read_nonblock(int argc, VALUE *argv, VALUE self)
}

static VALUE
ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)
ossl_ssl_write_internal_safe(VALUE _args)
{
VALUE *args = (VALUE*)_args;
VALUE self = args[0];
VALUE str = args[1];
VALUE opts = args[2];

SSL *ssl;
rb_io_t *fptr;
int num, nonblock = opts != Qfalse;
VALUE tmp, cb_state;
VALUE cb_state;

GetSSL(self, ssl);
if (!ssl_started(ssl))
rb_raise(eSSLError, "SSL session is not started yet");

tmp = rb_str_new_frozen(StringValue(str));
VALUE io = rb_attr_get(self, id_i_io);
GetOpenFile(io, fptr);

/* SSL_write(3ssl) manpage states num == 0 is undefined */
num = RSTRING_LENINT(tmp);
num = RSTRING_LENINT(str);
if (num == 0)
return INT2FIX(0);

for (;;) {
int nwritten = SSL_write(ssl, RSTRING_PTR(tmp), num);
int nwritten = SSL_write(ssl, RSTRING_PTR(str), num);

cb_state = rb_attr_get(self, ID_callback_state);
if (!NIL_P(cb_state)) {
Expand Down Expand Up @@ -2116,6 +2120,21 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)
}
}

static VALUE
ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)
{
VALUE args[3] = {self, str, opts};
int state;
str = rb_str_locktmp(StringValue(str));
VALUE result = rb_protect(ossl_ssl_write_internal_safe, (VALUE)args, &state);
rb_str_unlocktmp(str);

if (state) {
rb_jump_tag(state);
}
return result;
}

/*
* call-seq:
* ssl.syswrite(string) => Integer
Expand Down
48 changes: 25 additions & 23 deletions lib/openssl/buffering.rb
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,18 @@ module OpenSSL::Buffering
include Enumerable

# A buffer which will retain binary encoding.
class Buffer < String
BINARY = Encoding::BINARY

def initialize
super

force_encoding(BINARY)
end

def << string
if string.encoding == BINARY
super(string)
else
super(string.b)
if String.method_defined?(:append_as_bytes)
Buffer = String
else
class Buffer < String
def append_as_bytes(string)
if string.encoding == Encoding::BINARY || string.ascii_only?
self << string
else
self << string.b
end
end

return self
end

alias concat <<
end

##
Expand Down Expand Up @@ -352,22 +344,32 @@ def eof?

def do_write(s)
@wbuffer = Buffer.new unless defined? @wbuffer
@wbuffer << s
@wbuffer.force_encoding(Encoding::BINARY)
@wbuffer.append_as_bytes(s)

@sync ||= false
buffer_size = @wbuffer.size
buffer_size = @wbuffer.bytesize
if @sync or buffer_size > BLOCK_SIZE
nwrote = 0
begin
while nwrote < buffer_size do
begin
nwrote += syswrite(@wbuffer[nwrote, buffer_size - nwrote])
chunk = if nwrote > 0
@wbuffer.byteslice(nwrote..-1)
else
@wbuffer
end

nwrote += syswrite(chunk)
rescue Errno::EAGAIN
retry
end
end
ensure
@wbuffer[0, nwrote] = ""
if nwrote < @wbuffer.bytesize
@wbuffer[0, nwrote] = ""
else
@wbuffer.clear
end
end
end
end
Expand Down

0 comments on commit 0980748

Please sign in to comment.