Skip to content

Commit

Permalink
refactor: make CopyBody exported from rest package
Browse files Browse the repository at this point in the history
  • Loading branch information
benwaples committed Apr 2, 2024
1 parent f5003ec commit 970c6c7
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 65 deletions.
16 changes: 1 addition & 15 deletions client/rest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,23 +218,9 @@ func (s *restClient) Send(req *http.Request) (*http.Response, error) {
return s.send(req)
}

func copyBody(req *http.Request) ([]byte, error) {
var (
body []byte
err error
)
if req.Body != nil {
body, err = io.ReadAll(req.Body)
if body != nil {
req.Body = io.NopCloser(bytes.NewBuffer(body))
}
}
return body, err
}

func (s *restClient) send(req *http.Request) (*http.Response, error) {
// copy the bytes in case we need to retry the request
if body, err := copyBody(req); err != nil {
if body, err := CopyBody(req); err != nil {
return nil, err
} else {
var (
Expand Down
4 changes: 2 additions & 2 deletions client/rest/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func TestClosedConnection(t *testing.T) {

// make request in separate goroutine so its not blocking after we validated the retry
go func() {
client.Authenticate() // Authenticate()because it uses the internal client.send method.
// the above request should block this from running, however if it does then the test fails.
client.Authenticate() // Authenticate() because it uses the internal client.send method.
// the above request should block the request from completing, however if it does then the test fails.
requestCompleted = true
}()

Expand Down
21 changes: 18 additions & 3 deletions client/rest/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package rest

import (
"bytes"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
Expand All @@ -26,6 +27,7 @@ import (
"fmt"
"io"
"math"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -122,10 +124,9 @@ func x5t(certificate string) (string, error) {
}
}

var ClosedConnectionMsg = "An existing connection was forcibly closed by the remote host."

func IsClosedConnectionErr(err error) bool {
closedFromClient := strings.Contains(err.Error(), ClosedConnectionMsg)
var closedConnectionMsg = "An existing connection was forcibly closed by the remote host."
closedFromClient := strings.Contains(err.Error(), closedConnectionMsg)
// Mocking http.Do would require a larger refactor,
// so closedFromTestCase is used for testing only.
closedFromTestCase := strings.HasSuffix(err.Error(), ": EOF")
Expand All @@ -136,3 +137,17 @@ func ExponentialBackoff(retry int, maxRetries int) {
backoff := math.Pow(5, float64(retry+1))
time.Sleep(time.Second * time.Duration(backoff))
}

func CopyBody(req *http.Request) ([]byte, error) {
var (
body []byte
err error
)
if req.Body != nil {
body, err = io.ReadAll(req.Body)
if body != nil {
req.Body = io.NopCloser(bytes.NewBuffer(body))
}
}
return body, err
}
83 changes: 38 additions & 45 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"os/signal"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -229,8 +228,8 @@ func ingest(ctx context.Context, bheUrl url.URL, bheClient *http.Client, in <-ch
log.Error(err, unrecoverableErrMsg)
return true
} else if response.StatusCode == http.StatusGatewayTimeout || response.StatusCode == http.StatusServiceUnavailable || response.StatusCode == http.StatusBadGateway {
serverError := fmt.Errorf("received server error %d while requesting %v", response.StatusCode, endpoint)
log.Error(serverError, "attempt %d/%d", retry+1, maxRetries)
serverError := fmt.Errorf("received server error %d while requesting %v;", response.StatusCode, endpoint)
log.Error(serverError, "attempt %d/%d; trying again", retry+1, maxRetries)

rest.ExponentialBackoff(retry, maxRetries)

Expand Down Expand Up @@ -266,61 +265,55 @@ func ingest(ctx context.Context, bheUrl url.URL, bheClient *http.Client, in <-ch
// TODO: create/use a proper bloodhound client
func do(bheClient *http.Client, req *http.Request) (*http.Response, error) {
var (
body []byte
res *http.Response
err error
maxRetries = 3
)

// copy the bytes in case we need to retry the request
if req.Body != nil {
if body, err = io.ReadAll(req.Body); err != nil {
return nil, err
}
if body != nil {
req.Body = io.NopCloser(bytes.NewBuffer(body))
}
}
if body, err := rest.CopyBody(req); err != nil {
return nil, err
} else {
for retry := 0; retry < maxRetries; retry++ {
// Reusing http.Request requires rewinding the request body
// back to a working state
if body != nil && retry > 0 {
req.Body = io.NopCloser(bytes.NewBuffer(body))
}

for retry := 0; retry < maxRetries; retry++ {
// Reusing http.Request requires rewinding the request body
// back to a working state
if body != nil && retry > 0 {
req.Body = io.NopCloser(bytes.NewBuffer(body))
}
if res, err = bheClient.Do(req); err != nil {
if rest.IsClosedConnectionErr(err) {
// try again on force closed connections
log.Error(err, "remote host force closed connection while requesting %s; attempt %d/%d; trying again\n", req.URL, retry+1, maxRetries)
rest.ExponentialBackoff(retry, maxRetries)
continue
}
// normal client error, dont attempt again
return nil, err
} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
if res.StatusCode >= http.StatusInternalServerError {
// Internal server error, backoff and try again.
serverError := fmt.Errorf("received server error %d while requesting %v", res.StatusCode, req.URL)
log.Error(serverError, "attempt %d/%d", retry+1, maxRetries)

if res, err = bheClient.Do(req); err != nil {
if strings.Contains(err.Error(), rest.ClosedConnectionMsg) || strings.HasSuffix(err.Error(), ": EOF") {
// try again on force closed connections
log.Error(err, "remote host force closed connection while requesting %s; attempt %d/%d; trying again\n", req.URL, retry+1, maxRetries)
rest.ExponentialBackoff(retry, maxRetries)
continue
}
// normal client error, dont attempt again
return nil, err
} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
if res.StatusCode >= http.StatusInternalServerError {
// Internal server error, backoff and try again.
serverError := fmt.Errorf("received server error %d while requesting %v", res.StatusCode, req.URL)
log.Error(serverError, "attempt %d/%d", retry+1, maxRetries)

rest.ExponentialBackoff(retry, maxRetries)
continue
}
// bad request we do not need to retry
var body json.RawMessage
defer res.Body.Close()
if err := json.NewDecoder(res.Body).Decode(&body); err != nil {
return nil, fmt.Errorf("received unexpected response code from %v: %s; failure reading response body", req.URL, res.Status)
rest.ExponentialBackoff(retry, maxRetries)
continue
}
// bad request we do not need to retry
var body json.RawMessage
defer res.Body.Close()
if err := json.NewDecoder(res.Body).Decode(&body); err != nil {
return nil, fmt.Errorf("received unexpected response code from %v: %s; failure reading response body", req.URL, res.Status)
} else {
return nil, fmt.Errorf("received unexpected response code from %v: %s %s", req.URL, res.Status, body)
}
} else {
return nil, fmt.Errorf("received unexpected response code from %v: %s %s", req.URL, res.Status, body)
return res, nil
}
} else {
return res, nil
}
}

return nil, fmt.Errorf("unable to complete request | url=%s | attempts=%d | ERR=%w", req.URL, maxRetries, err)
return nil, fmt.Errorf("unable to complete request to url=%s; attempts=%d; ERR=%w", req.URL, maxRetries, err)
}

type basicResponse[T any] struct {
Expand Down

0 comments on commit 970c6c7

Please sign in to comment.