Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for custom retry functions #429

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ e.POST("/path").
WithRetryDelay(time.Second, time.Minute).
Expect().
Status(http.StatusOK)

// custom retry function
e.POST("/path").
WithMaxRetries(5).
WithCustomRetryFunc(func(resp *http.Response, err error) bool {
// your custom function here
return true
}).
Expect().
Status(http.StatusOK)
```

##### Subdomains and per-request URL
Expand Down
52 changes: 47 additions & 5 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ type Request struct {
redirectPolicy RedirectPolicy
maxRedirects int

retryPolicy RetryPolicy
maxRetries int
minRetryDelay time.Duration
maxRetryDelay time.Duration
sleepFn func(d time.Duration) <-chan time.Time
retryPolicy RetryPolicy
maxRetries int
minRetryDelay time.Duration
maxRetryDelay time.Duration
customRetryFunc func(resp *http.Response, err error) bool
sleepFn func(d time.Duration) <-chan time.Time

timeout time.Duration

Expand Down Expand Up @@ -721,6 +722,9 @@ const (

// RetryAllErrors enables retrying of any error or 4xx/5xx status code.
RetryAllErrors

// CustomRetryFunc adds a custom handler to execute on retry.
CustomRetryFunc
)

// WithRetryPolicy sets policy for retries.
Expand Down Expand Up @@ -852,6 +856,41 @@ func (r *Request) WithRetryDelay(minDelay, maxDelay time.Duration) *Request {
return r
}

// WithCustomRetryFunc sets a custom func to handle a result as a failure or a success.
//
// The handler function argument expects you to return `true` if you want to retry
// or `false` if you do not need to retry.
//
// Example:
//
// req := NewRequestC(config, "POST", "/path")
// req.WithRetryPolicy(1)
// req.WithCustomRetryFunc(func(resp *http.Response, err error) bool {
// return true
// })
// req.Expect().Status(http.StatusOK)
func (r *Request) WithCustomRetryFunc(
handler func(resp *http.Response, err error) bool,
) *Request {
opChain := r.chain.enter("WithCustomRetryFunc()")
defer opChain.leave()

r.mu.Lock()
defer r.mu.Unlock()

if opChain.failed() {
return r
}

if !r.checkOrder(opChain, "WithCustomRetryFunc()") {
return r
}

r.customRetryFunc = handler

return r
}

// WithWebsocketUpgrade enables upgrades the connection to websocket.
//
// At least the following fields are added to the request header:
Expand Down Expand Up @@ -2368,6 +2407,9 @@ func (r *Request) shouldRetry(resp *http.Response, err error) bool {

case RetryAllErrors:
return err != nil || isHTTPError

case CustomRetryFunc:
return r.customRetryFunc(resp, err)
}

return false
Expand Down
60 changes: 60 additions & 0 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func TestRequest_FailedChain(t *testing.T) {
req.WithMaxRedirects(1)
req.WithRetryPolicy(RetryAllErrors)
req.WithMaxRetries(1)
req.WithCustomRetryFunc(func(r *http.Response, err error) bool {
return true
})
req.WithRetryDelay(time.Millisecond, time.Millisecond)
req.WithWebsocketUpgrade()
req.WithWebsocketDialer(
Expand Down Expand Up @@ -3766,6 +3769,14 @@ func TestRequest_Order(t *testing.T) {
req.WithMaxRetries(10)
},
},
{
name: "WithRetryCustomFunc after Expect",
afterFunc: func(req *Request) {
req.WithCustomRetryFunc(func(r *http.Response, err error) bool {
return true
})
},
},
{
name: "WithRetryDelay after Expect",
afterFunc: func(req *Request) {
Expand Down Expand Up @@ -3991,3 +4002,52 @@ func TestRequest_Panics(t *testing.T) {
assert.Panics(t, func() { newRequest(newMockChain(t), config, "GET", "") })
})
}

func TestRequest_RetriesCustomFunc(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should cover for both

  • WithCustomRetryFunc that returns true
  • WithCustomRetryFunc that returns false

Wdyt?


client := &mockClient{
resp: http.Response{
StatusCode: http.StatusBadRequest,
},
cb: func(req *http.Request) {

assert.Error(t, req.Context().Err(), context.Canceled.Error())

b, err := io.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, "test body", string(b))
},
}

config := Config{
Client: client,
Reporter: newMockReporter(t),
}

ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately to trigger error

called := struct {
called bool
}{
called: true,
}

req := NewRequestC(config, http.MethodPost, "/url").
WithText("test body").
WithRetryPolicy(RetryAllErrors).
WithMaxRetries(2).
WithContext(ctx).
WithRetryDelay(0, 0).
WithCustomRetryFunc(func(resp *http.Response, err error) bool {
called.called = true
return true
})
req.chain.assert(t, success)

resp := req.Expect()
resp.chain.assert(t, failure)

// Should execute custom func
assert.True(t, called.called)
}