diff --git a/retrier.go b/retrier.go index 9bf36ee..d6ba25a 100644 --- a/retrier.go +++ b/retrier.go @@ -21,6 +21,7 @@ type Retrier struct { breakNext bool sleepFunc func(time.Duration) + timeout time.Duration intervalCalculator Strategy strategyType string @@ -72,6 +73,30 @@ func WithMaxAttempts(maxAttempts int) retrierOpt { } } +func WithTimeout(t time.Duration) retrierOpt { + if t < 0 { + panic("timeout must be a positive duration") + } + + return func(r *Retrier) { + if r.maxAttempts == 0 { + // It's possible to mix and match timeouts and max attempts, but if we've set a timeout and no max attempts, then + // we should just go until the timeout + r.maxAttempts = math.MaxInt + } + + r.timeout = t + } +} + +func TryUntil(t time.Time) retrierOpt { + if t.Before(time.Now()) { + panic("until time must be in the future") + } + + return WithTimeout(time.Until(t)) +} + func WithRand(rand *rand.Rand) retrierOpt { return func(r *Retrier) { r.rand = rand @@ -116,7 +141,8 @@ func WithSleepFunc(f func(time.Duration)) retrierOpt { // the retrier func NewRetrier(opts ...retrierOpt) *Retrier { r := &Retrier{ - rand: defaultRandom, + rand: defaultRandom, + timeout: time.Duration(math.MaxInt64), } for _, o := range opts { @@ -125,8 +151,8 @@ func NewRetrier(opts ...retrierOpt) *Retrier { // We use panics here rather than returning an error because all of these are logical issues caused by the programmer, // they should never occur in normal running, and can't be logically recovered from - if r.maxAttempts == 0 && !r.forever { - panic("retriers must either run forever, or have a maximum attempt count") + if r.maxAttempts == 0 && !r.forever && r.timeout == 0 { + panic("retriers must run forever, have a maximum attempt count, or have a timeout") } if r.maxAttempts < 0 { @@ -228,6 +254,9 @@ func (r *Retrier) Do(callback func(*Retrier) error) error { // DoWithContext is a context-aware variant of Do. func (r *Retrier) DoWithContext(ctx context.Context, callback func(*Retrier) error) error { + timeouter := time.NewTimer(r.timeout) // Defaults to math.MaxInt64 unless overridden + defer timeouter.Stop() + for { // Perform the action the user has requested we retry err := callback(r) @@ -251,19 +280,21 @@ func (r *Retrier) DoWithContext(ctx context.Context, callback func(*Retrier) err return err } - if err := r.sleepOrDone(ctx, nextInterval); err != nil { + if err := r.sleepOrDone(ctx, nextInterval, timeouter); err != nil { return err } } } -func (r *Retrier) sleepOrDone(ctx context.Context, nextInterval time.Duration) error { +func (r *Retrier) sleepOrDone(ctx context.Context, nextInterval time.Duration, timeouter *time.Timer) error { if r.sleepFunc == nil { t := time.NewTimer(nextInterval) defer t.Stop() select { case <-t.C: return nil + case <-timeouter.C: + return fmt.Errorf("retrier timed out after %v", r.timeout) case <-ctx.Done(): return ctx.Err() } @@ -275,6 +306,8 @@ func (r *Retrier) sleepOrDone(ctx context.Context, nextInterval time.Duration) e close(sleepCh) }() select { + case <-timeouter.C: + return fmt.Errorf("retrier timed out after %v", r.timeout) case <-sleepCh: return nil case <-ctx.Done(): diff --git a/retrier_test.go b/retrier_test.go index 3803914..1c864fe 100644 --- a/retrier_test.go +++ b/retrier_test.go @@ -112,6 +112,23 @@ func TestShouldGiveUp_WithMaxAttempts(t *testing.T) { assert.Equal(t, 3, callcount) } +func TestTimeout(t *testing.T) { + t.Parallel() + + callCount := 0 + err := NewRetrier( + WithStrategy(Constant(5*time.Millisecond)), + WithTimeout(500*time.Millisecond), + ).Do(func(_ *Retrier) error { + callCount += 1 + return errDummy + }) + + assert.Error(t, err) + assert.Equal(t, err, errors.New("retrier timed out after 500ms")) + assert.Equal(t, 100, callCount) +} + func TestShouldGiveUp_Break(t *testing.T) { t.Parallel()