Skip to content

Commit

Permalink
feat: add support for custom retry functions
Browse files Browse the repository at this point in the history
Co-authored-by: Alexis Couvreur <[email protected]>
  • Loading branch information
nkzren and acouvreur committed Oct 4, 2023
1 parent f945d4b commit 82a5a05
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 5 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ e.POST("/path").
WithRetryDelay(time.Second, time.Minute).
Expect().
Status(http.StatusOK)

// custom retry function
e.POST("/path").
WithMaxRetries(5).
WithCustomRetryFunc(func(resp *http.Response, err error) bool {
// your custom function here
return true
}).
Expect().
Status(http.StatusOK)
```

##### Subdomains and per-request URL
Expand Down
50 changes: 45 additions & 5 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ type Request struct {
redirectPolicy RedirectPolicy
maxRedirects int

retryPolicy RetryPolicy
maxRetries int
minRetryDelay time.Duration
maxRetryDelay time.Duration
sleepFn func(d time.Duration) <-chan time.Time
retryPolicy RetryPolicy
maxRetries int
minRetryDelay time.Duration
maxRetryDelay time.Duration
customRetryFunc func(resp *http.Response, err error) bool
sleepFn func(d time.Duration) <-chan time.Time

timeout time.Duration

Expand Down Expand Up @@ -721,6 +722,9 @@ const (

// RetryAllErrors enables retrying of any error or 4xx/5xx status code.
RetryAllErrors

// CustomRetryFunc adds a custom handler to execute on retry.
CustomRetryFunc
)

// WithRetryPolicy sets policy for retries.
Expand Down Expand Up @@ -852,6 +856,39 @@ func (r *Request) WithRetryDelay(minDelay, maxDelay time.Duration) *Request {
return r
}

// WithCustomRetryFunc sets a custom func to handle a result as a failure or a success.
//
// The handler function argument expects you to return `true` if you want to retry
// or `false` if you do not need to retry.
//
// Example:
//
// req := NewRequestC(config, "POST", "/path")
// req.WithRetryPolicy(1)
// req.WithCustomRetryFunc(func(resp *http.Response, err error) bool {
// return true
// })
// req.Expect().Status(http.StatusOK)
func (r *Request) WithCustomRetryFunc(handler func(resp *http.Response, err error) bool) *Request {

Check failure on line 872 in request.go

View workflow job for this annotation

GitHub Actions / Linters for root

line is 99 characters (lll)
opChain := r.chain.enter("WithCustomRetryFunc()")
defer opChain.leave()

r.mu.Lock()
defer r.mu.Unlock()

if opChain.failed() {
return r
}

if !r.checkOrder(opChain, "WithCustomRetryFunc()") {
return r
}

r.customRetryFunc = handler

return r
}

// WithWebsocketUpgrade enables upgrades the connection to websocket.
//
// At least the following fields are added to the request header:
Expand Down Expand Up @@ -2368,6 +2405,9 @@ func (r *Request) shouldRetry(resp *http.Response, err error) bool {

case RetryAllErrors:
return err != nil || isHTTPError

case CustomRetryFunc:
return r.customRetryFunc(resp, err)
}

return false
Expand Down
60 changes: 60 additions & 0 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func TestRequest_FailedChain(t *testing.T) {
req.WithMaxRedirects(1)
req.WithRetryPolicy(RetryAllErrors)
req.WithMaxRetries(1)
req.WithCustomRetryFunc(func(r *http.Response, err error) bool {
return true
})
req.WithRetryDelay(time.Millisecond, time.Millisecond)
req.WithWebsocketUpgrade()
req.WithWebsocketDialer(
Expand Down Expand Up @@ -3766,6 +3769,14 @@ func TestRequest_Order(t *testing.T) {
req.WithMaxRetries(10)
},
},
{
name: "WithRetryCustomFunc after Expect",
afterFunc: func(req *Request) {
req.WithCustomRetryFunc(func(r *http.Response, err error) bool {
return true
})
},
},
{
name: "WithRetryDelay after Expect",
afterFunc: func(req *Request) {
Expand Down Expand Up @@ -3991,3 +4002,52 @@ func TestRequest_Panics(t *testing.T) {
assert.Panics(t, func() { newRequest(newMockChain(t), config, "GET", "") })
})
}

func TestRequest_RetriesCustomFunc(t *testing.T) {

client := &mockClient{
resp: http.Response{
StatusCode: http.StatusBadRequest,
},
cb: func(req *http.Request) {

assert.Error(t, req.Context().Err(), context.Canceled.Error())

b, err := io.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, "test body", string(b))
},
}

config := Config{
Client: client,
Reporter: newMockReporter(t),
}

ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately to trigger error

called := struct {
called bool
}{
called: true,
}

req := NewRequestC(config, http.MethodPost, "/url").
WithText("test body").
WithRetryPolicy(RetryAllErrors).
WithMaxRetries(2).
WithContext(ctx).
WithRetryDelay(0, 0).
WithCustomRetryFunc(func(resp *http.Response, err error) bool {
called.called = true
return true
})
req.chain.assert(t, success)

resp := req.Expect()
resp.chain.assert(t, failure)

// Should execute custom func
assert.True(t, called.called)
}

0 comments on commit 82a5a05

Please sign in to comment.