Skip to content

Commit

Permalink
Merge pull request #1319 from charlie-retzler/improve-jwks-refresh-re…
Browse files Browse the repository at this point in the history
…try-mechanism

feat(authn): improve jwks refresh retry mechanism
  • Loading branch information
tolgaOzen authored Jun 28, 2024
2 parents f45b5ea + af360a2 commit 2f1c8e5
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 147 deletions.
64 changes: 45 additions & 19 deletions internal/authn/oidc/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ type Authn struct {
backoffFrequency time.Duration

// Global backoff state
globalRetryCount int
globalFirstSeen time.Time
mu sync.Mutex
globalRetryCount int
globalFirstSeen time.Time
globalRetryKeyIds map[string]bool
mu sync.Mutex
}

// NewOidcAuthn initializes a new instance of the Authn struct with OpenID Connect (OIDC) configuration.
Expand Down Expand Up @@ -98,6 +99,7 @@ func NewOidcAuthn(ctx context.Context, conf config.Oidc) (*Authn, error) {
backoffMaxRetries: backoffMaxRetries,
backoffFrequency: backoffFrequency,
globalRetryCount: 0,
globalRetryKeyIds: make(map[string]bool),
globalFirstSeen: time.Time{},
mu: sync.Mutex{},
}
Expand Down Expand Up @@ -203,6 +205,7 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
slog.Info("resetting state as interval has passed or first seen is zero", "keyID", keyID)
oidc.globalFirstSeen = now
oidc.globalRetryCount = 0
oidc.globalRetryKeyIds = make(map[string]bool)
} else if oidc.globalRetryCount >= oidc.backoffMaxRetries {
// If max retries reached within the interval, unlock and check keyID once
slog.Warn("max retries reached within interval, will check keyID once", "keyID", keyID)
Expand All @@ -211,11 +214,16 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
// Try to fetch the keyID once
rawKey, err = oidc.fetchKey(keyID, ctx)
if err == nil {
// Reset global backoff state if a valid key is found
slog.Info("valid key found during backoff period, resetting state", "keyID", keyID)
oidc.mu.Lock()
oidc.globalRetryCount = 0
oidc.globalFirstSeen = time.Time{}
if _, ok := oidc.globalRetryKeyIds[keyID]; ok {
// Reset global backoff state if a valid key is found and that key had been retried.
// Use case would be someone trying to exploit with bad KeyIDs, and along comes a valid KeyID
// The valid KeyID should not reset the counters for a bad key
slog.Info("valid key found during backoff period, resetting state", "keyID", keyID)
oidc.globalRetryCount = 0
oidc.globalFirstSeen = time.Time{}
oidc.globalRetryKeyIds = make(map[string]bool)
}
oidc.mu.Unlock()
return rawKey, nil
}
Expand All @@ -229,6 +237,26 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
// Retry mechanism
retries := 0
for retries <= oidc.backoffMaxRetries {
rawKey, err = oidc.fetchKey(keyID, ctx)
if err == nil {
if retries != 0 {
oidc.mu.Lock()
oidc.globalRetryCount = 0
oidc.globalFirstSeen = time.Time{}
oidc.globalRetryKeyIds = make(map[string]bool)
oidc.mu.Unlock()
}
return rawKey, nil
}
oidc.mu.Lock()
initialGlobalRetryCount := oidc.globalRetryCount
oidc.globalRetryKeyIds[keyID] = true
if oidc.globalRetryCount > oidc.backoffMaxRetries {
slog.Error("key ID not found in JWKS due to global retries", "keyID", keyID, "globalRetryCount", oidc.globalRetryCount)
oidc.mu.Unlock()
return nil, errors.New("too many attempts, backoff in effect due to global retry count")
}
oidc.mu.Unlock()
if retries > 0 {
select {
case <-time.After(oidc.backoffFrequency):
Expand All @@ -240,28 +268,26 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
}
}

rawKey, err = oidc.fetchKey(keyID, ctx)
if err == nil {
// Log the successful key fetch and reset global state
slog.Info("successfully fetched key", "keyID", keyID)
oidc.mu.Lock()
oidc.globalRetryCount = 0
oidc.globalFirstSeen = time.Time{}
oidc.mu.Lock()
if oidc.globalRetryCount > initialGlobalRetryCount {
// Concurrent requests in retry loop at same time, another concurrent request already refreshed the JWKS
retries++
slog.Warn("another concurrent request already refreshed the JWKS")
oidc.mu.Unlock()
return rawKey, nil
continue
}

oidc.globalRetryCount++
slog.Warn("retrying to fetch JWKS due to error", "keyID", keyID, "retries", retries, "error", err)
retries++

oidc.mu.Lock()
oidc.globalRetryCount++
oidc.mu.Unlock()

if _, refreshErr := oidc.jwksSet.Refresh(ctx, oidc.JwksURI); refreshErr != nil {
oidc.mu.Unlock()
slog.Error("failed to refresh JWKS", "error", refreshErr)
return nil, refreshErr
}
// Unlock needs to follow Refresh to ensure that concurrent requests don't make duplicate calls to Refresh
oidc.mu.Unlock()
}

// Mark the global state to prevent further retries for the backoff interval
Expand Down
Loading

0 comments on commit 2f1c8e5

Please sign in to comment.