diff --git a/README.md b/README.md index 6a9971e0d..e0b0d680e 100644 --- a/README.md +++ b/README.md @@ -404,6 +404,16 @@ e.POST("/path"). WithRetryDelay(time.Second, time.Minute). Expect(). Status(http.StatusOK) + +// custom retry function +e.POST("/path"). + WithMaxRetries(5). + WithCustomRetryFunc(func(resp *http.Response, err error) bool { + // your custom function here + return true + }). + Expect(). + Status(http.StatusOK) ``` ##### Subdomains and per-request URL diff --git a/request.go b/request.go index 7e2312e34..fc1d11ed6 100644 --- a/request.go +++ b/request.go @@ -36,11 +36,12 @@ type Request struct { redirectPolicy RedirectPolicy maxRedirects int - retryPolicy RetryPolicy - maxRetries int - minRetryDelay time.Duration - maxRetryDelay time.Duration - sleepFn func(d time.Duration) <-chan time.Time + retryPolicy RetryPolicy + maxRetries int + minRetryDelay time.Duration + maxRetryDelay time.Duration + customRetryFunc func(resp *http.Response, err error) bool + sleepFn func(d time.Duration) <-chan time.Time timeout time.Duration @@ -721,6 +722,9 @@ const ( // RetryAllErrors enables retrying of any error or 4xx/5xx status code. RetryAllErrors + + // CustomRetryFunc adds a custom handler to execute on retry. + CustomRetryFunc ) // WithRetryPolicy sets policy for retries. @@ -852,6 +856,41 @@ func (r *Request) WithRetryDelay(minDelay, maxDelay time.Duration) *Request { return r } +// WithCustomRetryFunc sets a custom func to handle a result as a failure or a success. +// +// The handler function argument expects you to return `true` if you want to retry +// or `false` if you do not need to retry. +// +// Example: +// +// req := NewRequestC(config, "POST", "/path") +// req.WithRetryPolicy(1) +// req.WithCustomRetryFunc(func(resp *http.Response, err error) bool { +// return true +// }) +// req.Expect().Status(http.StatusOK) +func (r *Request) WithCustomRetryFunc( + handler func(resp *http.Response, err error) bool, +) *Request { + opChain := r.chain.enter("WithCustomRetryFunc()") + defer opChain.leave() + + r.mu.Lock() + defer r.mu.Unlock() + + if opChain.failed() { + return r + } + + if !r.checkOrder(opChain, "WithCustomRetryFunc()") { + return r + } + + r.customRetryFunc = handler + + return r +} + // WithWebsocketUpgrade enables upgrades the connection to websocket. // // At least the following fields are added to the request header: @@ -2368,6 +2407,9 @@ func (r *Request) shouldRetry(resp *http.Response, err error) bool { case RetryAllErrors: return err != nil || isHTTPError + + case CustomRetryFunc: + return r.customRetryFunc(resp, err) } return false diff --git a/request_test.go b/request_test.go index e83a31f40..5395f824e 100644 --- a/request_test.go +++ b/request_test.go @@ -46,6 +46,9 @@ func TestRequest_FailedChain(t *testing.T) { req.WithMaxRedirects(1) req.WithRetryPolicy(RetryAllErrors) req.WithMaxRetries(1) + req.WithCustomRetryFunc(func(r *http.Response, err error) bool { + return true + }) req.WithRetryDelay(time.Millisecond, time.Millisecond) req.WithWebsocketUpgrade() req.WithWebsocketDialer( @@ -3766,6 +3769,14 @@ func TestRequest_Order(t *testing.T) { req.WithMaxRetries(10) }, }, + { + name: "WithRetryCustomFunc after Expect", + afterFunc: func(req *Request) { + req.WithCustomRetryFunc(func(r *http.Response, err error) bool { + return true + }) + }, + }, { name: "WithRetryDelay after Expect", afterFunc: func(req *Request) { @@ -3991,3 +4002,52 @@ func TestRequest_Panics(t *testing.T) { assert.Panics(t, func() { newRequest(newMockChain(t), config, "GET", "") }) }) } + +func TestRequest_RetriesCustomFunc(t *testing.T) { + + client := &mockClient{ + resp: http.Response{ + StatusCode: http.StatusBadRequest, + }, + cb: func(req *http.Request) { + + assert.Error(t, req.Context().Err(), context.Canceled.Error()) + + b, err := io.ReadAll(req.Body) + assert.NoError(t, err) + assert.Equal(t, "test body", string(b)) + }, + } + + config := Config{ + Client: client, + Reporter: newMockReporter(t), + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately to trigger error + + called := struct { + called bool + }{ + called: true, + } + + req := NewRequestC(config, http.MethodPost, "/url"). + WithText("test body"). + WithRetryPolicy(RetryAllErrors). + WithMaxRetries(2). + WithContext(ctx). + WithRetryDelay(0, 0). + WithCustomRetryFunc(func(resp *http.Response, err error) bool { + called.called = true + return true + }) + req.chain.assert(t, success) + + resp := req.Expect() + resp.chain.assert(t, failure) + + // Should execute custom func + assert.True(t, called.called) +}