From 5608ee700801d372beffd3c94359f874eafadaf6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sandor=20Sz=C3=BCcs?= Date: Wed, 6 Mar 2024 18:37:44 +0100 Subject: [PATCH] feature: httpclient supporting retry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sandor Szücs --- net/httpclient.go | 89 ++++++++++++++++++++-- net/httpclient_test.go | 169 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 250 insertions(+), 8 deletions(-) diff --git a/net/httpclient.go b/net/httpclient.go index bf9c7d7aa5..833fab3898 100644 --- a/net/httpclient.go +++ b/net/httpclient.go @@ -1,6 +1,7 @@ package net import ( + "bytes" "crypto/tls" "fmt" "io" @@ -23,15 +24,54 @@ const ( defaultRefreshInterval = 5 * time.Minute ) +type mybuf struct{ *bytes.Buffer } + +func (buf *mybuf) Close() error { + return nil +} + +type copyBodyStream struct { + left int + buf *mybuf + input io.ReadCloser +} + +func newCopyBodyStream(left int, buf *bytes.Buffer, rc io.ReadCloser) *copyBodyStream { + return ©BodyStream{ + left: left, + buf: &mybuf{Buffer: buf}, + input: rc, + } +} + +func (cb *copyBodyStream) Read(p []byte) (n int, err error) { + n, err = cb.input.Read(p) + if cb.left > 0 && n > 0 { + m := min(n, cb.left) + cb.buf.Write(p[:m]) + cb.left -= m + } + return n, err +} + +func (cb *copyBodyStream) Close() error { + return cb.input.Close() +} + +func (cb *copyBodyStream) GetBody() io.ReadCloser { + return cb.buf +} + // Client adds additional features like Bearer token injection, and // opentracing to the wrapped http.Client with the same interface as // http.Client from the stdlib. type Client struct { - once sync.Once - client http.Client - tr *Transport - log logging.Logger - sr secrets.SecretsReader + once sync.Once + client http.Client + tr *Transport + log logging.Logger + sr secrets.SecretsReader + retryBuffers *sync.Map } // NewClient creates a wrapped http.Client and uses Transport to @@ -67,9 +107,10 @@ func NewClient(o Options) *Client { Transport: tr, CheckRedirect: o.CheckRedirect, }, - tr: tr, - log: o.Log, - sr: sr, + tr: tr, + log: o.Log, + sr: sr, + retryBuffers: &sync.Map{}, } return c @@ -125,9 +166,41 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { req.Header.Set("Authorization", "Bearer "+string(b)) } } + if req.Body != nil && req.Body != http.NoBody { + retryBuffer := newCopyBodyStream(int(req.ContentLength), &bytes.Buffer{}, req.Body) + c.retryBuffers.Store(req, retryBuffer) + req.Body = retryBuffer + } return c.client.Do(req) } +func (c *Client) Retry(req *http.Request) (*http.Response, error) { + if req.Body == nil || req.Body == http.NoBody { + return c.Do(req) + } + + if rc, err := req.GetBody(); err == nil { + println("req.GetBody() case") + c.retryBuffers.Delete(req) + req.Body = rc + return c.Do(req) + } + + println("our own retry buffer impl") + buf, ok := c.retryBuffers.Load(req) + if !ok { + return nil, fmt.Errorf("no retry possible, request not found: %s %s", req.Method, req.URL) + } + + retryBuffer, ok := buf.(*copyBodyStream) + if !ok { + return nil, fmt.Errorf("no retry possible, no retry buffer for request: %s %s", req.Method, req.URL) + } + req.Body = retryBuffer.GetBody() + + return c.Do(req) +} + // CloseIdleConnections delegates the call to the underlying // http.Client. func (c *Client) CloseIdleConnections() { diff --git a/net/httpclient_test.go b/net/httpclient_test.go index 0cb2ff7dd7..27466b36d9 100644 --- a/net/httpclient_test.go +++ b/net/httpclient_test.go @@ -1,6 +1,8 @@ package net import ( + "bytes" + "io" "net/http" "net/http/httptest" "net/url" @@ -324,3 +326,170 @@ func TestClientClosesIdleConnections(t *testing.T) { } rsp.Body.Close() } + +func TestTestClientRetry(t *testing.T) { + for _, tt := range []struct { + name string + method string + body string + }{ + { + name: "test GET", + method: "GET", + }, + { + name: "test POST", + method: "POST", + body: "hello POST", + }, + { + name: "test PATCH", + method: "PATCH", + body: "hello PATCH", + }, + { + name: "test PUT", + method: "PUT", + body: "hello PUT", + }} { + t.Run(tt.name, func(t *testing.T) { + i := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if i == 0 { + i++ + w.WriteHeader(http.StatusBadGateway) + } + + got, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("got no data") + } + s := string(got) + if tt.body != s { + t.Fatalf("Failed to get the right data want: %q, got: %q", tt.body, s) + } + + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + noleak.Check(t) + + cli := NewClient(Options{}) + defer cli.Close() + + buf := bytes.NewBufferString(tt.body) + req, err := http.NewRequest(tt.method, backend.URL, buf) + if err != nil { + t.Fatal(err) + } + rsp, err := cli.Do(req) + if err != nil { + t.Fatal(err) + } + if rsp.StatusCode != http.StatusBadGateway { + t.Fatalf("unexpected status code: %s", rsp.Status) + } + + rsp, err = cli.Retry(req) + if rsp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %s", rsp.Status) + } + rsp.Body.Close() + }) + } +} + +func TestTestClientRetryConcurrentRequests(t *testing.T) { + for _, tt := range []struct { + name string + method string + body string + }{ + { + name: "test GET", + method: "GET", + }, + { + name: "test POST", + method: "POST", + body: "hello POST", + }, + { + name: "test PATCH", + method: "PATCH", + body: "hello PATCH", + }, + { + name: "test PUT", + method: "PUT", + body: "hello PUT", + }} { + t.Run(tt.name, func(t *testing.T) { + i := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ignore" { + w.WriteHeader(http.StatusOK) + return + } + + if i == 0 { + i++ + io.ReadAll(r.Body) + w.WriteHeader(http.StatusBadGateway) + return + } + + got, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("got no data") + } + s := string(got) + if tt.body != s { + t.Fatalf("Failed to get the right data want: %q, got: %q", tt.body, s) + } + + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + noleak.Check(t) + + cli := NewClient(Options{}) + defer cli.Close() + + quit := make(chan struct{}) + go func() { + for { + select { + case <-quit: + return + default: + } + cli.Get(backend.URL + "/ignore") + } + }() + + buf := bytes.NewBufferString(tt.body) + req, err := http.NewRequest(tt.method, backend.URL, buf) + if err != nil { + t.Fatal(err) + } + rsp, err := cli.Do(req) + if err != nil { + t.Fatal(err) + } + if rsp.StatusCode != http.StatusBadGateway { + t.Fatalf("unexpected status code: %s", rsp.Status) + } + + rsp, err = cli.Retry(req) + if rsp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %s", rsp.Status) + } + rsp.Body.Close() + + close(quit) + }) + } +}