diff --git a/v2/workloadapi/backoff.go b/v2/workloadapi/backoff.go index b6ef1ed..5ff9126 100644 --- a/v2/workloadapi/backoff.go +++ b/v2/workloadapi/backoff.go @@ -5,30 +5,51 @@ import ( "time" ) -// backoff defines an linear backoff policy. -type backoff struct { - InitialDelay time.Duration - MaxDelay time.Duration +// BackoffStrategy provides backoff facilities. +type BackoffStrategy interface { + // NewBackoff returns a new backoff for the strategy. The returned + // Backoff is in the same state that it would be in after a call to + // Reset(). + NewBackoff() Backoff +} + +// Backoff provides backoff for a workload API operation. +type Backoff interface { + // Next returns the next backoff period. + Next() time.Duration + + // Reset() resets the backoff. + Reset() +} + +type defaultBackoffStrategy struct{} + +func (defaultBackoffStrategy) NewBackoff() Backoff { + return newLinearBackoff() +} + +// linearBackoff defines an linear backoff policy. +type linearBackoff struct { + initialDelay time.Duration + maxDelay time.Duration n int } -func newBackoff() *backoff { - return &backoff{ - InitialDelay: time.Second, - MaxDelay: 30 * time.Second, +func newLinearBackoff() *linearBackoff { + return &linearBackoff{ + initialDelay: time.Second, + maxDelay: 30 * time.Second, n: 0, } } -// Duration returns the next wait period for the backoff. Not goroutine-safe. -func (b *backoff) Duration() time.Duration { +func (b *linearBackoff) Next() time.Duration { backoff := float64(b.n) + 1 - d := math.Min(b.InitialDelay.Seconds()*backoff, b.MaxDelay.Seconds()) + d := math.Min(b.initialDelay.Seconds()*backoff, b.maxDelay.Seconds()) b.n++ return time.Duration(d) * time.Second } -// Reset resets the backoff's state. -func (b *backoff) Reset() { +func (b *linearBackoff) Reset() { b.n = 0 } diff --git a/v2/workloadapi/backoff_test.go b/v2/workloadapi/backoff_test.go index 9e25e32..e9b4c76 100644 --- a/v2/workloadapi/backoff_test.go +++ b/v2/workloadapi/backoff_test.go @@ -7,34 +7,27 @@ import ( "github.com/stretchr/testify/require" ) -func TestBackoff(t *testing.T) { - new := func() *backoff { //nolint:all - b := newBackoff() - b.InitialDelay = time.Second - b.MaxDelay = 30 * time.Second - return b - } - - testUntilMax := func(t *testing.T, b *backoff) { +func TestLinearBackoff(t *testing.T) { + testUntilMax := func(t *testing.T, b *linearBackoff) { for i := 1; i < 30; i++ { - require.Equal(t, time.Duration(i)*time.Second, b.Duration()) + require.Equal(t, time.Duration(i)*time.Second, b.Next()) } - require.Equal(t, 30*time.Second, b.Duration()) - require.Equal(t, 30*time.Second, b.Duration()) - require.Equal(t, 30*time.Second, b.Duration()) + require.Equal(t, 30*time.Second, b.Next()) + require.Equal(t, 30*time.Second, b.Next()) + require.Equal(t, 30*time.Second, b.Next()) } t.Run("test max", func(t *testing.T) { t.Parallel() - b := new() + b := newLinearBackoff() testUntilMax(t, b) }) t.Run("test reset", func(t *testing.T) { t.Parallel() - b := new() + b := newLinearBackoff() testUntilMax(t, b) b.Reset() diff --git a/v2/workloadapi/client.go b/v2/workloadapi/client.go index 4d5de5d..7739798 100644 --- a/v2/workloadapi/client.go +++ b/v2/workloadapi/client.go @@ -119,7 +119,7 @@ func (c *Client) FetchX509Bundles(ctx context.Context) (*x509bundle.Set, error) // WatchX509Bundles watches for changes to the X.509 bundles. The watcher receives // the updated X.509 bundles. func (c *Client) WatchX509Bundles(ctx context.Context, watcher X509BundleWatcher) error { - backoff := newBackoff() + backoff := c.config.backoffStrategy.NewBackoff() for { err := c.watchX509Bundles(ctx, watcher, backoff) watcher.OnX509BundlesWatchError(err) @@ -152,7 +152,7 @@ func (c *Client) FetchX509Context(ctx context.Context) (*X509Context, error) { // WatchX509Context watches for updates to the X.509 context. The watcher // receives the updated X.509 context. func (c *Client) WatchX509Context(ctx context.Context, watcher X509ContextWatcher) error { - backoff := newBackoff() + backoff := c.config.backoffStrategy.NewBackoff() for { err := c.watchX509Context(ctx, watcher, backoff) watcher.OnX509ContextWatchError(err) @@ -224,7 +224,7 @@ func (c *Client) FetchJWTBundles(ctx context.Context) (*jwtbundle.Set, error) { // WatchJWTBundles watches for changes to the JWT bundles. The watcher receives // the updated JWT bundles. func (c *Client) WatchJWTBundles(ctx context.Context, watcher JWTBundleWatcher) error { - backoff := newBackoff() + backoff := c.config.backoffStrategy.NewBackoff() for { err := c.watchJWTBundles(ctx, watcher, backoff) watcher.OnJWTBundlesWatchError(err) @@ -258,7 +258,7 @@ func (c *Client) newConn(ctx context.Context) (*grpc.ClientConn, error) { return grpc.DialContext(ctx, c.config.address, c.config.dialOptions...) //nolint:staticcheck // preserve backcompat with WithDialOptions option } -func (c *Client) handleWatchError(ctx context.Context, err error, backoff *backoff) error { +func (c *Client) handleWatchError(ctx context.Context, err error, backoff Backoff) error { code := status.Code(err) if code == codes.Canceled { return err @@ -270,7 +270,7 @@ func (c *Client) handleWatchError(ctx context.Context, err error, backoff *backo } c.config.log.Errorf("Failed to watch the Workload API: %v", err) - retryAfter := backoff.Duration() + retryAfter := backoff.Next() c.config.log.Debugf("Retrying watch in %s", retryAfter) select { case <-time.After(retryAfter): @@ -281,7 +281,7 @@ func (c *Client) handleWatchError(ctx context.Context, err error, backoff *backo } } -func (c *Client) watchX509Context(ctx context.Context, watcher X509ContextWatcher, backoff *backoff) error { +func (c *Client) watchX509Context(ctx context.Context, watcher X509ContextWatcher, backoff Backoff) error { ctx, cancel := context.WithCancel(withHeader(ctx)) defer cancel() @@ -308,7 +308,7 @@ func (c *Client) watchX509Context(ctx context.Context, watcher X509ContextWatche } } -func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher, backoff *backoff) error { +func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher, backoff Backoff) error { ctx, cancel := context.WithCancel(withHeader(ctx)) defer cancel() @@ -335,7 +335,7 @@ func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher, } } -func (c *Client) watchX509Bundles(ctx context.Context, watcher X509BundleWatcher, backoff *backoff) error { +func (c *Client) watchX509Bundles(ctx context.Context, watcher X509BundleWatcher, backoff Backoff) error { ctx, cancel := context.WithCancel(withHeader(ctx)) defer cancel() @@ -402,7 +402,8 @@ func withHeader(ctx context.Context) context.Context { func defaultClientConfig() clientConfig { return clientConfig{ - log: logger.Null, + log: logger.Null, + backoffStrategy: defaultBackoffStrategy{}, } } diff --git a/v2/workloadapi/client_test.go b/v2/workloadapi/client_test.go index 5040d8f..aa2653d 100644 --- a/v2/workloadapi/client_test.go +++ b/v2/workloadapi/client_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/x509" "sync" + "sync/atomic" "testing" "time" @@ -103,7 +104,10 @@ func TestFetchX509Bundles(t *testing.T) { func TestWatchX509Bundles(t *testing.T) { wl := fakeworkloadapi.New(t) defer wl.Stop() - c, err := New(context.Background(), WithAddr(wl.Addr())) + + backoffStrategy := &testBackoffStrategy{} + + c, err := New(context.Background(), WithAddr(wl.Addr()), WithBackoffStrategy(backoffStrategy)) require.NoError(t, err) defer c.Close() @@ -149,6 +153,9 @@ func TestWatchX509Bundles(t *testing.T) { wl.Stop() tw.WaitForUpdates(1) assert.Len(t, tw.Errors(), 2) + + // Assert that there was the expected number of backoffs. + assert.Equal(t, 2, backoffStrategy.BackedOff()) } func TestFetchX509Context(t *testing.T) { @@ -213,7 +220,10 @@ func TestWatchX509Context(t *testing.T) { federatedCA := test.NewCA(t, federatedTD) wl := fakeworkloadapi.New(t) defer wl.Stop() - c, err := New(context.Background(), WithAddr(wl.Addr())) + + backoffStrategy := &testBackoffStrategy{} + + c, err := New(context.Background(), WithAddr(wl.Addr()), WithBackoffStrategy(backoffStrategy)) require.NoError(t, err) defer c.Close() @@ -291,6 +301,9 @@ func TestWatchX509Context(t *testing.T) { cancel() wg.Wait() + + // Assert that there was the expected number of backoffs. + assert.Equal(t, 2, backoffStrategy.BackedOff()) } func TestFetchJWTSVID(t *testing.T) { @@ -375,7 +388,10 @@ func TestFetchJWTBundles(t *testing.T) { func TestWatchJWTBundles(t *testing.T) { wl := fakeworkloadapi.New(t) defer wl.Stop() - c, err := New(context.Background(), WithAddr(wl.Addr())) + + backoffStrategy := &testBackoffStrategy{} + + c, err := New(context.Background(), WithAddr(wl.Addr()), WithBackoffStrategy(backoffStrategy)) require.NoError(t, err) defer c.Close() @@ -421,6 +437,9 @@ func TestWatchJWTBundles(t *testing.T) { wl.Stop() tw.WaitForUpdates(1) assert.Len(t, tw.Errors(), 2) + + // Assert that there was the expected number of backoffs. + assert.Equal(t, 2, backoffStrategy.BackedOff()) } func TestValidateJWTSVID(t *testing.T) { @@ -605,3 +624,26 @@ func (w *testWatcher) WaitForUpdates(expectedNumUpdates int) { } } } + +type testBackoffStrategy struct { + backedOff int32 +} + +func (s *testBackoffStrategy) NewBackoff() Backoff { + return testBackoff{backedOff: &s.backedOff} +} + +func (s *testBackoffStrategy) BackedOff() int { + return int(atomic.LoadInt32(&s.backedOff)) +} + +type testBackoff struct { + backedOff *int32 +} + +func (b testBackoff) Next() time.Duration { + atomic.AddInt32(b.backedOff, 1) + return time.Millisecond * 200 +} + +func (b testBackoff) Reset() {} diff --git a/v2/workloadapi/option.go b/v2/workloadapi/option.go index 00cab7d..4997b8b 100644 --- a/v2/workloadapi/option.go +++ b/v2/workloadapi/option.go @@ -35,6 +35,14 @@ func WithLogger(logger logger.Logger) ClientOption { }) } +// WithBackoff provides a custom backoff strategy that replaces the +// default backoff strategy (linear backoff). +func WithBackoffStrategy(backoffStrategy BackoffStrategy) ClientOption { + return clientOption(func(c *clientConfig) { + c.backoffStrategy = backoffStrategy + }) +} + // SourceOption are options that are shared among all option types. type SourceOption interface { configureX509Source(*x509SourceConfig) @@ -81,10 +89,11 @@ type BundleSourceOption interface { } type clientConfig struct { - address string - namedPipeName string - dialOptions []grpc.DialOption - log logger.Logger + address string + namedPipeName string + dialOptions []grpc.DialOption + log logger.Logger + backoffStrategy BackoffStrategy } type clientOption func(*clientConfig)