diff --git a/retry/adaptive.go b/retry/adaptive.go index 123201b..98f53cf 100644 --- a/retry/adaptive.go +++ b/retry/adaptive.go @@ -15,7 +15,6 @@ package retry import ( - "context" "fmt" "math/bits" "sync/atomic" @@ -24,27 +23,47 @@ import ( "github.com/ecodeclub/ekit/internal/errs" ) +var _ Strategy = (*AdaptiveTimeoutRetryStrategy)(nil) + type AdaptiveTimeoutRetryStrategy struct { strategy Strategy // 基础重试策略 threshold int // 超时比率阈值 (单位:比特数量) ringBuffer []uint64 // 比特环(滑动窗口存储超时信息) reqCount uint64 // 当前滑动窗口内超时的数量 ringBufferLen int // 滑动窗口长度 + } -func (s *AdaptiveTimeoutRetryStrategy) Next(ctx context.Context, err error) (time.Duration, bool) { - if err == nil { - s.markSuccess() - return 0, false - } +func (s *AdaptiveTimeoutRetryStrategy) Next() (time.Duration, bool) { failCount := s.getFailed() - s.markFail() if failCount >= s.threshold { return 0, false } - return s.strategy.Next(ctx, err) + return s.strategy.Next() } +func (s *AdaptiveTimeoutRetryStrategy) Report(err error) Strategy { + if err == nil { + s.markSuccess() + } else { + s.markFail() + } + return s +} + +//func (s *AdaptiveTimeoutRetryStrategy) Next(ctx context.Context, err error) (time.Duration, bool) { +// if err == nil { +// s.markSuccess() +// return 0, false +// } +// failCount := s.getFailed() +// s.markFail() +// if failCount >= s.threshold { +// return 0, false +// } +// return s.strategy.Next(ctx, err) +//} + func (s *AdaptiveTimeoutRetryStrategy) markSuccess() { count := atomic.AddUint64(&s.reqCount, 1) count = count % (uint64(64) * uint64(len(s.ringBuffer))) diff --git a/retry/adaptive_test.go b/retry/adaptive_test.go index 34a3fe9..f6cd977 100644 --- a/retry/adaptive_test.go +++ b/retry/adaptive_test.go @@ -15,7 +15,6 @@ package retry import ( - "context" "errors" "fmt" "sync" @@ -89,32 +88,24 @@ func TestAdaptiveTimeoutRetryStrategy_Next(t *testing.T) { tests := []struct { name string - err error wantDelay time.Duration wantOk bool }{ { name: "error below threshold", - err: errors.New("test error"), wantDelay: 1 * time.Second, wantOk: true, }, { name: "error above threshold", - err: errors.New("test error"), wantDelay: 1 * time.Second, wantOk: true, }, - { - name: "not retry", - wantDelay: 0, - wantOk: false, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - delay, ok := strategy.Next(context.Background(), tt.err) + delay, ok := strategy.Next() assert.Equal(t, tt.wantDelay, delay) assert.Equal(t, tt.wantOk, ok) }) @@ -147,7 +138,8 @@ func TestAdaptiveTimeoutRetryStrategy_Next_Concurrent(t *testing.T) { if index >= 1500 { err = mockErr } - _, allowed := strategy.Next(context.Background(), err) + strategy.Report(err) + _, allowed := strategy.Next() if err != nil { // 失败请求的统计 if allowed { @@ -187,11 +179,10 @@ func ExampleAdaptiveTimeoutRetryStrategy_Next() { fmt.Println(err) return } - nextErr := errors.New("test error") - interval, ok := strategy.Next(context.Background(), nextErr) + interval, ok := strategy.Next() for ok { fmt.Println(interval) - interval, ok = strategy.Next(context.Background(), nextErr) + interval, ok = strategy.Next() } // Output: // 1s @@ -209,6 +200,10 @@ func ExampleAdaptiveTimeoutRetryStrategy_Next() { type MockStrategy struct { } -func (m MockStrategy) Next(ctx context.Context, err error) (time.Duration, bool) { +func (m MockStrategy) Next() (time.Duration, bool) { return 1 * time.Second, true } + +func (m MockStrategy) Report(err error) Strategy { + return m +} diff --git a/retry/exponential.go b/retry/exponential.go index 2af18eb..8d06e5a 100644 --- a/retry/exponential.go +++ b/retry/exponential.go @@ -15,7 +15,6 @@ package retry import ( - "context" "math" "sync/atomic" "time" @@ -23,6 +22,8 @@ import ( "github.com/ecodeclub/ekit/internal/errs" ) +var _ Strategy = (*ExponentialBackoffRetryStrategy)(nil) + // ExponentialBackoffRetryStrategy 指数退避重试 type ExponentialBackoffRetryStrategy struct { // 初始重试间隔 @@ -50,14 +51,11 @@ func NewExponentialBackoffRetryStrategy(initialInterval, maxInterval time.Durati }, nil } -func (s *ExponentialBackoffRetryStrategy) Next(ctx context.Context, err error) (time.Duration, bool) { - if err != nil { - return s.next() - } - return 0, false +func (s *ExponentialBackoffRetryStrategy) Report(err error) Strategy { + return s } -func (s *ExponentialBackoffRetryStrategy) next() (time.Duration, bool) { +func (s *ExponentialBackoffRetryStrategy) Next() (time.Duration, bool) { retries := atomic.AddInt32(&s.retries, 1) if s.maxRetries <= 0 || retries <= s.maxRetries { if reached, ok := s.maxIntervalReached.Load().(bool); ok && reached { diff --git a/retry/exponential_test.go b/retry/exponential_test.go index c5e5b0a..d8f4c4e 100644 --- a/retry/exponential_test.go +++ b/retry/exponential_test.go @@ -16,7 +16,6 @@ package retry import ( "context" - "errors" "fmt" "testing" "time" @@ -84,16 +83,14 @@ func TestNewExponentialBackoffRetryStrategy_New(t *testing.T) { func TestExponentialBackoffRetryStrategy_Next(t *testing.T) { testCases := []struct { name string - nextErr error ctx context.Context strategy *ExponentialBackoffRetryStrategy wantIntervals []time.Duration }{ { - name: "stop if retries reaches maxRetries", - ctx: context.Background(), - nextErr: errors.New("test error"), + name: "stop if retries reaches maxRetries", + ctx: context.Background(), strategy: func() *ExponentialBackoffRetryStrategy { s, err := NewExponentialBackoffRetryStrategy(1*time.Second, 10*time.Second, 3) require.NoError(t, err) @@ -103,9 +100,8 @@ func TestExponentialBackoffRetryStrategy_Next(t *testing.T) { wantIntervals: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second}, }, { - name: "initialInterval over maxInterval", - ctx: context.Background(), - nextErr: errors.New("test error"), + name: "initialInterval over maxInterval", + ctx: context.Background(), strategy: func() *ExponentialBackoffRetryStrategy { s, err := NewExponentialBackoffRetryStrategy(1*time.Second, 4*time.Second, 5) require.NoError(t, err) @@ -114,22 +110,12 @@ func TestExponentialBackoffRetryStrategy_Next(t *testing.T) { wantIntervals: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second, 4 * time.Second, 4 * time.Second}, }, - { - name: "not retry", - ctx: context.Background(), - strategy: func() *ExponentialBackoffRetryStrategy { - s, err := NewExponentialBackoffRetryStrategy(1*time.Second, 4*time.Second, 5) - require.NoError(t, err) - return s - }(), - wantIntervals: []time.Duration{}, - }, } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { intervals := make([]time.Duration, 0) for { - if interval, ok := tt.strategy.Next(tt.ctx, tt.nextErr); ok { + if interval, ok := tt.strategy.Next(); ok { intervals = append(intervals, interval) } else { break @@ -160,10 +146,10 @@ func ExampleExponentialBackoffRetryStrategy_Next() { fmt.Println(err) return } - interval, ok := retry.next() + interval, ok := retry.Next() for ok { fmt.Println(interval) - interval, ok = retry.next() + interval, ok = retry.Next() } // Output: // 1s diff --git a/retry/fixed_internal.go b/retry/fixed_internal.go index 0aa45d7..ea210db 100644 --- a/retry/fixed_internal.go +++ b/retry/fixed_internal.go @@ -15,13 +15,14 @@ package retry import ( - "context" "sync/atomic" "time" "github.com/ecodeclub/ekit/internal/errs" ) +var _ Strategy = (*FixedIntervalRetryStrategy)(nil) + // FixedIntervalRetryStrategy 等间隔重试 type FixedIntervalRetryStrategy struct { maxRetries int32 // 最大重试次数,如果是 0 或负数,表示无限重试 @@ -39,17 +40,14 @@ func NewFixedIntervalRetryStrategy(interval time.Duration, maxRetries int32) (*F }, nil } -func (s *FixedIntervalRetryStrategy) Next(ctx context.Context, err error) (time.Duration, bool) { - if err != nil { - return s.next() - } - return 0, false -} - -func (s *FixedIntervalRetryStrategy) next() (time.Duration, bool) { +func (s *FixedIntervalRetryStrategy) Next() (time.Duration, bool) { retries := atomic.AddInt32(&s.retries, 1) if s.maxRetries <= 0 || retries <= s.maxRetries { return s.interval, true } return 0, false } + +func (s *FixedIntervalRetryStrategy) Report(err error) Strategy { + return s +} diff --git a/retry/fixed_internal_test.go b/retry/fixed_internal_test.go index dab5ee8..e9deabb 100644 --- a/retry/fixed_internal_test.go +++ b/retry/fixed_internal_test.go @@ -16,7 +16,6 @@ package retry import ( "context" - "errors" "fmt" "testing" "time" @@ -32,7 +31,6 @@ func TestFixedIntervalRetryStrategy_Next(t *testing.T) { testCases := []struct { name string - nextErr error ctx context.Context s *FixedIntervalRetryStrategy interval time.Duration @@ -40,9 +38,8 @@ func TestFixedIntervalRetryStrategy_Next(t *testing.T) { isContinue bool }{ { - name: "init case, retries 0", - ctx: context.Background(), - nextErr: errors.New("test error"), + name: "init case, retries 0", + ctx: context.Background(), s: &FixedIntervalRetryStrategy{ maxRetries: 3, interval: time.Second, @@ -51,9 +48,8 @@ func TestFixedIntervalRetryStrategy_Next(t *testing.T) { isContinue: true, }, { - name: "retries equals to MaxRetries 3 after the increase", - ctx: context.Background(), - nextErr: errors.New("test error"), + name: "retries equals to MaxRetries 3 after the increase", + ctx: context.Background(), s: &FixedIntervalRetryStrategy{ maxRetries: 3, interval: time.Second, @@ -63,9 +59,8 @@ func TestFixedIntervalRetryStrategy_Next(t *testing.T) { isContinue: true, }, { - name: "retries over MaxRetries after the increase", - ctx: context.Background(), - nextErr: errors.New("test error"), + name: "retries over MaxRetries after the increase", + ctx: context.Background(), s: &FixedIntervalRetryStrategy{ maxRetries: 3, interval: time.Second, @@ -75,9 +70,8 @@ func TestFixedIntervalRetryStrategy_Next(t *testing.T) { isContinue: false, }, { - name: "MaxRetries equals to 0", - ctx: context.Background(), - nextErr: errors.New("test error"), + name: "MaxRetries equals to 0", + ctx: context.Background(), s: &FixedIntervalRetryStrategy{ maxRetries: 0, interval: time.Second, @@ -86,9 +80,8 @@ func TestFixedIntervalRetryStrategy_Next(t *testing.T) { isContinue: true, }, { - name: "negative MaxRetries", - ctx: context.Background(), - nextErr: errors.New("test error"), + name: "negative MaxRetries", + ctx: context.Background(), s: &FixedIntervalRetryStrategy{ maxRetries: -1, interval: time.Second, @@ -97,19 +90,10 @@ func TestFixedIntervalRetryStrategy_Next(t *testing.T) { interval: time.Second, isContinue: true, }, - { - name: "not retry", - ctx: context.Background(), - s: &FixedIntervalRetryStrategy{ - maxRetries: -1, - interval: time.Second, - retries: 0, - }, - }, } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { - interval, isContinue := tt.s.Next(tt.ctx, tt.nextErr) + interval, isContinue := tt.s.Next() assert.Equal(t, tt.interval, interval) assert.Equal(t, tt.isContinue, isContinue) }) @@ -176,7 +160,7 @@ func testNext4InfiniteRetry(t *testing.T, maxRetries int32) { intervals := make([]time.Duration, 0, n) for i := 0; i < n; i++ { - res, _ := s.next() + res, _ := s.Next() intervals = append(intervals, res) } assert.Equal(t, wantIntervals, intervals) @@ -188,10 +172,10 @@ func ExampleFixedIntervalRetryStrategy_Next() { fmt.Println(err) return } - interval, ok := retry.next() + interval, ok := retry.Next() for ok { fmt.Println(interval) - interval, ok = retry.next() + interval, ok = retry.Next() } // Output: // 1s diff --git a/retry/retry.go b/retry/retry.go index 2c10a7e..fe6b9e1 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -41,7 +41,7 @@ func Retry(ctx context.Context, if err == nil { return nil } - duration, ok := s.Next(ctx, err) + duration, ok := s.Next() if !ok { return errs.NewErrRetryExhausted(err) } diff --git a/retry/types.go b/retry/types.go index c431790..11ae405 100644 --- a/retry/types.go +++ b/retry/types.go @@ -15,11 +15,11 @@ package retry import ( - "context" "time" ) type Strategy interface { // Next 返回下一次重试的间隔,如果不需要继续重试,那么第二参数返回 false - Next(ctx context.Context, err error) (time.Duration, bool) + Next() (time.Duration, bool) + Report(err error) Strategy }