From f85fce88868f92f63b384c2f5bb70d7fd880bf93 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Fri, 8 Dec 2023 16:33:00 -0500 Subject: [PATCH 1/5] Implement unary HTTP calls with retry support Changes the internal duplexHTTPCall to switch on non streaming client requests to block and wait for the response. This avoids the need to convert the reader to a writer with io.Pipe and the go routine to run asynchronously. On unary requests we are now able to set the `Content-Length` header and the `GetBody` function for retries. To safely reuse the payload buffer a new type `payloadCloser` is added to implement the required HTTP body semantics. On receiving a response the request may still be read up and until the response is body is closed. To ensure the request body is safe to releasee we wait for a complete read or close on the request body. Retries may read, close or rewind the body multiple times before a response is returned. --- connect_ext_test.go | 16 +- duplex_http_call.go | 222 +++++++++++++++++++++------ duplex_http_call_test.go | 122 +++++++++++++++ envelope.go | 2 +- internal/memhttp/memhttptest/http.go | 2 +- 5 files changed, 311 insertions(+), 53 deletions(-) create mode 100644 duplex_http_call_test.go diff --git a/connect_ext_test.go b/connect_ext_test.go index 2cb04a07..9091a17f 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -434,8 +434,9 @@ func TestServer(t *testing.T) { pingServer{checkMetadata: true}, ) errorWriter := connect.NewErrorWriter() - // Add some net/http middleware to the ping service so we can also exercise ErrorWriter. + // Add net/http middleware to the ping service to evaluate HTTP state. mux.Handle(pingRoute, http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + // Exercise ErrorWriter for HTTP middleware errors. if request.Header.Get(clientMiddlewareErrorHeader) != "" { defer request.Body.Close() if _, err := io.Copy(io.Discard, request.Body); err != nil { @@ -449,6 +450,19 @@ func TestServer(t *testing.T) { } return } + // Check Content-Length is set correctly. + switch request.URL.Path { + case pingv1connect.PingServicePingProcedure, + pingv1connect.PingServiceFailProcedure, + pingv1connect.PingServiceCountUpProcedure: + if request.ContentLength < 0 { + t.Errorf("%s: expected Content-Length >= 0, got %d", request.URL.Path, request.ContentLength) + } + default: + if request.ContentLength > 0 { + t.Errorf("%s: expected Content-Length -1 or 0, got %d", request.URL.Path, request.ContentLength) + } + } pingHandler.ServeHTTP(response, request) })) diff --git a/duplex_http_call.go b/duplex_http_call.go index 5833b39c..fdfbb2a2 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -21,6 +21,7 @@ import ( "io" "net/http" "net/url" + "sync" "sync/atomic" ) @@ -36,10 +37,8 @@ type duplexHTTPCall struct { onRequestSend func(*http.Request) validateResponse func(*http.Response) *Error - // We'll use a pipe as the request body. We hand the read side of the pipe to - // net/http, and we write to the write side (naturally). The two ends are - // safe to use concurrently. - requestBodyReader *io.PipeReader + // io.Pipe is used to implement the request body for client streaming calls. + // If the request is unary, requestBodyWriter is nil. requestBodyWriter *io.PipeWriter // requestSent ensures we only send the request once. @@ -65,7 +64,6 @@ func newDuplexHTTPCall( // Request. This ensures if a transport out of our control wants // to mutate the req.URL, we don't feel the effects of it. url = cloneURL(url) - pipeReader, pipeWriter := io.Pipe() // This is mirroring what http.NewRequestContext did, but // using an already parsed url.URL object, rather than a string @@ -74,30 +72,40 @@ func newDuplexHTTPCall( // NewRequestContext and doesn't effect the actual version // being transmitted. request := (&http.Request{ - Method: http.MethodPost, - URL: url, - Header: header, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Body: pipeReader, - Host: url.Host, + Method: http.MethodPost, + URL: url, + Header: header, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: http.NoBody, + ContentLength: 0, + Host: url.Host, }).WithContext(ctx) return &duplexHTTPCall{ - ctx: ctx, - httpClient: httpClient, - streamType: spec.StreamType, - requestBodyReader: pipeReader, - requestBodyWriter: pipeWriter, - request: request, - responseReady: make(chan struct{}), + ctx: ctx, + httpClient: httpClient, + streamType: spec.StreamType, + request: request, + responseReady: make(chan struct{}), } } // Send sends a message to the server. -func (d *duplexHTTPCall) Send(payload messsagePayload) (int64, error) { - isFirst := d.ensureRequestMade() - // Before we send any data, check if the context has been canceled. +func (d *duplexHTTPCall) Send(payload messagePayload) (int64, error) { + if d.streamType&StreamTypeClient == 0 { + return d.sendUnary(payload) + } + isFirst := d.requestSent.CompareAndSwap(false, true) + if isFirst { + // This is the first time we're sending a message to the server. + // We need to send the request headers and start the request. + pipeReader, pipeWriter := io.Pipe() + d.requestBodyWriter = pipeWriter + d.request.Body = pipeReader + d.request.ContentLength = -1 + go d.makeRequest() // concurrent request + } if err := d.ctx.Err(); err != nil { return 0, wrapIfContextError(err) } @@ -113,18 +121,55 @@ func (d *duplexHTTPCall) Send(payload messsagePayload) (int64, error) { // Signal that the stream is closed with the more-typical io.EOF instead of // io.ErrClosedPipe. This makes it easier for protocol-specific wrappers to // match grpc-go's behavior. - return bytesWritten, io.EOF + err = io.EOF } return bytesWritten, err } +func (d *duplexHTTPCall) sendUnary(payload messagePayload) (int64, error) { + // Unary messages are sent as a single HTTP request. We don't need to use a + // pipe for the request body and we don't need to send headers separately. + if !d.requestSent.CompareAndSwap(false, true) { + return 0, fmt.Errorf("request already sent") + } + payloadLength := int64(payload.Len()) + if payloadLength > 0 { + // Build the request body from the payload. + payloadBody := newPayloadCloser(payload) + d.request.Body = payloadBody + d.request.ContentLength = payloadLength + d.request.GetBody = func() (io.ReadCloser, error) { + if !payloadBody.Rewind() { + return nil, fmt.Errorf("payload cannot be retried") + } + return payloadBody, nil + } + // Wait on the payloadBody to be completly read or closed before + // returning from Send. This ensures that the payload can be reused + // after Send returns. See [http.RoundTripper] for more details. + defer payloadBody.Wait() + } else { + d.request.GetBody = func() (io.ReadCloser, error) { + return http.NoBody, nil + } + } + d.makeRequest() // synchronous request + if err := d.ctx.Err(); err != nil { + return 0, wrapIfContextError(err) + } + return payloadLength, nil +} + // Close the request body. Callers *must* call CloseWrite before Read when // using HTTP/1.x. func (d *duplexHTTPCall) CloseWrite() error { // Even if Write was never called, we need to make an HTTP request. This // ensures that we've sent any headers to the server and that we have an HTTP // response to read from. - d.ensureRequestMade() + if d.requestSent.CompareAndSwap(false, true) { + go d.makeRequest() + return nil + } // The user calls CloseWrite to indicate that they're done sending data. It's // safe to close the write side of the pipe while net/http is reading from // it. @@ -136,7 +181,10 @@ func (d *duplexHTTPCall) CloseWrite() error { // forever. To make sure users don't have to worry about this, the generated // code for unary, client streaming, and server streaming RPCs must call // CloseWrite automatically rather than requiring the user to do it. - return d.requestBodyWriter.Close() + if d.requestBodyWriter != nil { + return d.requestBodyWriter.Close() + } + return d.request.Body.Close() } // Header returns the HTTP request headers. @@ -171,9 +219,6 @@ func (d *duplexHTTPCall) Read(data []byte) (int, error) { if err := d.ctx.Err(); err != nil { return 0, wrapIfContextError(err) } - if d.response == nil { - return 0, fmt.Errorf("nil response from %v", d.request.URL) - } n, err := d.response.Body.Read(data) return n, wrapIfRSTError(err) } @@ -233,17 +278,6 @@ func (d *duplexHTTPCall) BlockUntilResponseReady() error { return d.responseErr } -// ensureRequestMade sends the request headers and starts the response stream. -// It is not safe to call this concurrently. Write and CloseWrite call this but -// ensure that they're not called concurrently. -func (d *duplexHTTPCall) ensureRequestMade() (isFirst bool) { - if d.requestSent.CompareAndSwap(false, true) { - go d.makeRequest() - return true - } - return false -} - func (d *duplexHTTPCall) makeRequest() { // This runs concurrently with Write and CloseWrite. Read and CloseRead wait // on d.responseReady, so we can't race with them. @@ -253,7 +287,6 @@ func (d *duplexHTTPCall) makeRequest() { if host := d.request.Header.Get(headerHost); len(host) > 0 { d.request.Host = host } - if d.onRequestSend != nil { d.onRequestSend(d.request) } @@ -272,7 +305,7 @@ func (d *duplexHTTPCall) makeRequest() { err = NewError(CodeUnavailable, err) } d.responseErr = err - d.requestBodyWriter.Close() + _ = d.CloseWrite() return } // We've got a response. We can now read from the response body. @@ -280,7 +313,7 @@ func (d *duplexHTTPCall) makeRequest() { d.response = response if err := d.validateResponse(response); err != nil { d.responseErr = err - d.requestBodyWriter.Close() + _ = d.CloseWrite() return } if (d.streamType&StreamTypeBidi) == StreamTypeBidi && response.ProtoMajor < 2 { @@ -293,13 +326,13 @@ func (d *duplexHTTPCall) makeRequest() { response.ProtoMajor, response.ProtoMinor, ) - d.requestBodyWriter.Close() + _ = d.CloseWrite() } } -// messsagePayload is a sized and seekable message payload. The interface is -// implemented by [*bytes.Reader] and *envelope. -type messsagePayload interface { +// messagePayload is a sized and seekable message payload. The interface is +// implemented by [*bytes.Reader] and *envelope. Reads must be non-blocking. +type messagePayload interface { io.Reader io.WriterTo io.Seeker @@ -310,7 +343,7 @@ type messsagePayload interface { // to the server. type nopPayload struct{} -var _ messsagePayload = nopPayload{} +var _ messagePayload = nopPayload{} func (nopPayload) Read([]byte) (int, error) { return 0, io.EOF @@ -328,7 +361,7 @@ func (nopPayload) Len() int { // messageSender sends a message payload. The interface is implemented by // [*duplexHTTPCall] and writeSender. type messageSender interface { - Send(messsagePayload) (int64, error) + Send(messagePayload) (int64, error) } // writeSender is a sender that writes to an [io.Writer]. Useful for wrapping @@ -339,7 +372,7 @@ type writeSender struct { var _ messageSender = writeSender{} -func (w writeSender) Send(payload messsagePayload) (int64, error) { +func (w writeSender) Send(payload messagePayload) (int64, error) { return payload.WriteTo(w.writer) } @@ -356,3 +389,92 @@ func cloneURL(oldURL *url.URL) *url.URL { } return newURL } + +// payloadCloser is an [io.ReadCloser] that wraps a messagePayload. It's used to +// implement the request body for unary calls. To safely reuse the buffer +// call Wait after the response is received to ensure the payload has been +// drained or closed. After Wait, the payload cannot be rewound. It's safe to +// call Close multiple times. +type payloadCloser struct { + mu sync.Mutex + cond sync.Cond + payload messagePayload + isDone bool // true if the payload has been fully read +} + +func newPayloadCloser(payload messagePayload) *payloadCloser { + closer := &payloadCloser{ + payload: payload, + } + closer.cond.L = &closer.mu + return closer +} + +// Read implements [io.Reader], on error it signals that the payload has been +// fully read. +func (p *payloadCloser) Read(dst []byte) (readN int, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.payload == nil { + return 0, io.EOF + } + readN, err = p.payload.Read(dst) + if err != nil || p.payload.Len() == 0 { + p.completeWithLock() + } + return readN, err +} + +// WriteTo implements [io.WriterTo]. It signals that the payload has been fully +// read. +func (p *payloadCloser) WriteTo(dst io.Writer) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.isDone { + return 0, nil + } + n, err := p.payload.WriteTo(dst) + p.completeWithLock() + return n, err +} + +// Close implements [io.Closer]. It signals completion of the payload. +func (p *payloadCloser) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + p.completeWithLock() + return nil +} + +// Rewind rewinds the payload to the beginning. It returns false if the +// payload has been discarded from a previous call to Wait. +func (p *payloadCloser) Rewind() bool { + p.mu.Lock() + defer p.mu.Unlock() + if p.payload == nil { + return false + } + if _, err := p.payload.Seek(0, io.SeekStart); err != nil { + return false + } + p.isDone = false + return true +} + +// Wait blocks until the payload has been fully read or closed. After Wait, the +// payload is discarded and cannot be rewound. It's then safe to reuse the +// payload. +func (p *payloadCloser) Wait() { + p.mu.Lock() + for !p.isDone { + p.cond.Wait() + } + p.payload = nil + p.mu.Unlock() +} + +// completeWithLock signals that the payload has been fully read. +func (p *payloadCloser) completeWithLock() { + p.isDone = true + p.cond.Broadcast() +} diff --git a/duplex_http_call_test.go b/duplex_http_call_test.go new file mode 100644 index 00000000..13643e49 --- /dev/null +++ b/duplex_http_call_test.go @@ -0,0 +1,122 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connect + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + + "connectrpc.com/connect/internal/assert" +) + +// TestHTTPCallGetBody tests that the client is able to retry requests on +// connection close errors. It will initialize a closing handler and ensure +// http.Request.GetBody is successfully called to replay the request. +func TestHTTPCallGetBody(t *testing.T) { + t.Parallel() + handler := http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + // The "Connection: close" header is turned into a GOAWAY frame by the http2 server. + responseWriter.Header().Add("Connection", "close") + _, _ = io.Copy(responseWriter, request.Body) + _ = request.Body.Close() + }) + // Must use httptest for this test. + server := httptest.NewUnstartedServer(handler) + server.EnableHTTP2 = true + server.StartTLS() + t.Cleanup(server.Close) + bufferPool := newBufferPool() + serverURL, _ := url.Parse(server.URL) + errGetBodyCalled := errors.New("getBodyCalled") // sentinel error + caller := func(size int) error { + call := newDuplexHTTPCall( + context.Background(), + server.Client(), + serverURL, + Spec{StreamType: StreamTypeUnary}, + http.Header{}, + ) + getBodyCalled := false + call.onRequestSend = func(*http.Request) { + getBody := call.request.GetBody + call.request.GetBody = func() (io.ReadCloser, error) { + getBodyCalled = true + rdcloser, err := getBody() + assert.Nil(t, err) + return rdcloser, err + } + } + // SetValidateResponse must be set. + call.SetValidateResponse(func(*http.Response) *Error { + return nil + }) + buf := bufferPool.Get() + defer bufferPool.Put(buf) + buf.Write(make([]byte, size)) + _, err := call.Send(bytes.NewReader(buf.Bytes())) + assert.Nil(t, err) + assert.Nil(t, call.CloseWrite()) + buf.Reset() + _, err = io.Copy(buf, call) + assert.Nil(t, err) + assert.Equal(t, buf.Len(), size) + if getBodyCalled { + return errGetBodyCalled + } + return nil + } + type work struct { + size int + errs chan error + } + numWorkers := 2 + workChan := make(chan work) + wg := sync.WaitGroup{} + wg.Add(numWorkers) + for i := 0; i < numWorkers; i++ { + go func() { + for work := range workChan { + work.errs <- caller(work.size) + } + wg.Done() + }() + } + for i, gotGetBody := 0, false; !gotGetBody; i++ { + errs := make([]chan error, numWorkers) + for i := 0; i < numWorkers; i++ { + errs[i] = make(chan error, 1) + workChan <- work{size: 512, errs: errs[i]} + } + t.Log("waiting", i) + for _, errChan := range errs { + if err := <-errChan; err != nil { + if errors.Is(err, errGetBodyCalled) { + gotGetBody = true + } else { + t.Fatal(err) + } + } + } + } + close(workChan) + wg.Wait() +} diff --git a/envelope.go b/envelope.go index c9f67b36..6d0e6255 100644 --- a/envelope.go +++ b/envelope.go @@ -45,7 +45,7 @@ type envelope struct { offset int64 } -var _ messsagePayload = (*envelope)(nil) +var _ messagePayload = (*envelope)(nil) func (e *envelope) IsSet(flag uint8) bool { return e.Flags&flag == flag diff --git a/internal/memhttp/memhttptest/http.go b/internal/memhttp/memhttptest/http.go index f6d90c55..5d797e4b 100644 --- a/internal/memhttp/memhttptest/http.go +++ b/internal/memhttp/memhttptest/http.go @@ -37,7 +37,7 @@ func NewServer(tb testing.TB, handler http.Handler, opts ...memhttp.Option) *mem server := memhttp.NewServer(handler, opts...) tb.Cleanup(func() { if err := server.Cleanup(); err != nil { - tb.Error(err) + tb.Errorf("shutdown failed: %v", err) } }) return server From 35dbabb3b4fbdd735aef3312a3ef351ace134657 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 11 Dec 2023 15:57:42 -0500 Subject: [PATCH 2/5] Add GetBody for all http.NoBody requests Ensure we set GetBody on all zero length requests. --- duplex_http_call.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/duplex_http_call.go b/duplex_http_call.go index fdfbb2a2..e1b21859 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -79,6 +79,7 @@ func newDuplexHTTPCall( ProtoMajor: 1, ProtoMinor: 1, Body: http.NoBody, + GetBody: getNoBody, ContentLength: 0, Host: url.Host, }).WithContext(ctx) @@ -103,6 +104,7 @@ func (d *duplexHTTPCall) Send(payload messagePayload) (int64, error) { pipeReader, pipeWriter := io.Pipe() d.requestBodyWriter = pipeWriter d.request.Body = pipeReader + d.request.GetBody = nil // GetBody not supported for client streaming d.request.ContentLength = -1 go d.makeRequest() // concurrent request } @@ -148,10 +150,6 @@ func (d *duplexHTTPCall) sendUnary(payload messagePayload) (int64, error) { // returning from Send. This ensures that the payload can be reused // after Send returns. See [http.RoundTripper] for more details. defer payloadBody.Wait() - } else { - d.request.GetBody = func() (io.ReadCloser, error) { - return http.NoBody, nil - } } d.makeRequest() // synchronous request if err := d.ctx.Err(); err != nil { @@ -330,6 +328,11 @@ func (d *duplexHTTPCall) makeRequest() { } } +// getNoBody is a GetBody function for http.NoBody. +func getNoBody() (io.ReadCloser, error) { + return http.NoBody, nil +} + // messagePayload is a sized and seekable message payload. The interface is // implemented by [*bytes.Reader] and *envelope. Reads must be non-blocking. type messagePayload interface { From 12c36b0c0cfb2d66b042888d5fc927bcaf20b168 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 14 Dec 2023 11:05:43 -0500 Subject: [PATCH 3/5] Remove ctx check on response errors --- duplex_http_call.go | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/duplex_http_call.go b/duplex_http_call.go index e1b21859..ab9b79f2 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -72,16 +72,15 @@ func newDuplexHTTPCall( // NewRequestContext and doesn't effect the actual version // being transmitted. request := (&http.Request{ - Method: http.MethodPost, - URL: url, - Header: header, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Body: http.NoBody, - GetBody: getNoBody, - ContentLength: 0, - Host: url.Host, + Method: http.MethodPost, + URL: url, + Header: header, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: http.NoBody, + GetBody: getNoBody, + Host: url.Host, }).WithContext(ctx) return &duplexHTTPCall{ ctx: ctx, @@ -152,8 +151,12 @@ func (d *duplexHTTPCall) sendUnary(payload messagePayload) (int64, error) { defer payloadBody.Wait() } d.makeRequest() // synchronous request - if err := d.ctx.Err(); err != nil { - return 0, wrapIfContextError(err) + if d.responseErr != nil { + // Check on response errors for context errors. Other errors are + // handled on read. + if err := d.ctx.Err(); err != nil { + return 0, wrapIfContextError(err) + } } return payloadLength, nil } @@ -418,7 +421,7 @@ func newPayloadCloser(payload messagePayload) *payloadCloser { func (p *payloadCloser) Read(dst []byte) (readN int, err error) { p.mu.Lock() defer p.mu.Unlock() - if p.payload == nil { + if p.isDone { return 0, io.EOF } readN, err = p.payload.Read(dst) From 7686b3ec09bbdae588a6650c164a0ee18189234c Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 14 Dec 2023 11:08:59 -0500 Subject: [PATCH 4/5] Clarify Content-Length checks --- connect_ext_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 9091a17f..76f47499 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -455,13 +455,18 @@ func TestServer(t *testing.T) { case pingv1connect.PingServicePingProcedure, pingv1connect.PingServiceFailProcedure, pingv1connect.PingServiceCountUpProcedure: + // Unary requests set Content-Length to the length of the request body. if request.ContentLength < 0 { t.Errorf("%s: expected Content-Length >= 0, got %d", request.URL.Path, request.ContentLength) } - default: + case pingv1connect.PingServiceSumProcedure, + pingv1connect.PingServiceCumSumProcedure: + // Streaming requests set Content-Length to -1 or 0 on empty requests. if request.ContentLength > 0 { t.Errorf("%s: expected Content-Length -1 or 0, got %d", request.URL.Path, request.ContentLength) } + default: + t.Errorf("unexpected path %q", request.URL.Path) } pingHandler.ServeHTTP(response, request) })) From 6862f70f47b3b9f6036e106a38d5741428c0a69f Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 19 Dec 2023 17:48:07 +0000 Subject: [PATCH 5/5] Drop paylaod on Send return Removes the wait and instead simplifies to dropping the payload when the response is received. Should be valid for all unary servers. --- duplex_http_call.go | 65 +++++++++++++-------------------------------- 1 file changed, 19 insertions(+), 46 deletions(-) diff --git a/duplex_http_call.go b/duplex_http_call.go index ab9b79f2..fe048f45 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -145,10 +145,10 @@ func (d *duplexHTTPCall) sendUnary(payload messagePayload) (int64, error) { } return payloadBody, nil } - // Wait on the payloadBody to be completly read or closed before - // returning from Send. This ensures that the payload can be reused - // after Send returns. See [http.RoundTripper] for more details. - defer payloadBody.Wait() + // Release the payload ensuring that after Send returns the + // payload is safe to be reused. See [http.RoundTripper] for + // more details. + defer payloadBody.Release() } d.makeRequest() // synchronous request if d.responseErr != nil { @@ -398,62 +398,46 @@ func cloneURL(oldURL *url.URL) *url.URL { // payloadCloser is an [io.ReadCloser] that wraps a messagePayload. It's used to // implement the request body for unary calls. To safely reuse the buffer -// call Wait after the response is received to ensure the payload has been -// drained or closed. After Wait, the payload cannot be rewound. It's safe to -// call Close multiple times. +// call Release after the response is received to ensure the payload is safe for +// reuse. type payloadCloser struct { mu sync.Mutex - cond sync.Cond - payload messagePayload - isDone bool // true if the payload has been fully read + payload messagePayload // nil after Release } func newPayloadCloser(payload messagePayload) *payloadCloser { - closer := &payloadCloser{ + return &payloadCloser{ payload: payload, } - closer.cond.L = &closer.mu - return closer } -// Read implements [io.Reader], on error it signals that the payload has been -// fully read. +// Read implements [io.Reader]. func (p *payloadCloser) Read(dst []byte) (readN int, err error) { p.mu.Lock() defer p.mu.Unlock() - if p.isDone { + if p.payload == nil { return 0, io.EOF } - readN, err = p.payload.Read(dst) - if err != nil || p.payload.Len() == 0 { - p.completeWithLock() - } - return readN, err + return p.payload.Read(dst) } -// WriteTo implements [io.WriterTo]. It signals that the payload has been fully -// read. +// WriteTo implements [io.WriterTo]. func (p *payloadCloser) WriteTo(dst io.Writer) (int64, error) { p.mu.Lock() defer p.mu.Unlock() - if p.isDone { + if p.payload == nil { return 0, nil } - n, err := p.payload.WriteTo(dst) - p.completeWithLock() - return n, err + return p.payload.WriteTo(dst) } -// Close implements [io.Closer]. It signals completion of the payload. +// Close implements [io.Closer]. func (p *payloadCloser) Close() error { - p.mu.Lock() - defer p.mu.Unlock() - p.completeWithLock() return nil } // Rewind rewinds the payload to the beginning. It returns false if the -// payload has been discarded from a previous call to Wait. +// payload has been discarded from a previous call to Release. func (p *payloadCloser) Rewind() bool { p.mu.Lock() defer p.mu.Unlock() @@ -463,24 +447,13 @@ func (p *payloadCloser) Rewind() bool { if _, err := p.payload.Seek(0, io.SeekStart); err != nil { return false } - p.isDone = false return true } -// Wait blocks until the payload has been fully read or closed. After Wait, the -// payload is discarded and cannot be rewound. It's then safe to reuse the -// payload. -func (p *payloadCloser) Wait() { +// Release discards the payload. After Release is called, the payload cannot be +// rewound and the payload is safe to reuse. +func (p *payloadCloser) Release() { p.mu.Lock() - for !p.isDone { - p.cond.Wait() - } p.payload = nil p.mu.Unlock() } - -// completeWithLock signals that the payload has been fully read. -func (p *payloadCloser) completeWithLock() { - p.isDone = true - p.cond.Broadcast() -}