Skip to content

Commit

Permalink
feat(authn): improve retry mechanism for fetching JWKS
Browse files Browse the repository at this point in the history
  • Loading branch information
charlie-retzler committed Jun 25, 2024
1 parent c66661e commit af360a2
Showing 1 changed file with 86 additions and 128 deletions.
214 changes: 86 additions & 128 deletions internal/authn/oidc/authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,23 @@ var _ = Describe("authn-oidc", func() {
audience := "aud"
listenAddress := "localhost:9999"
issuerURL := "http://" + listenAddress
fakeOidcProvider, _ := newFakeOidcProvider(ProviderConfig{
IssuerURL: issuerURL,
AuthPath: "/auth",
TokenPath: "/token",
UserInfoPath: "/userInfo",
JWKSPath: "/jwks",
Algorithms: []string{"RS256", "HS256", "ES256", "PS256"},
})
var fakeOidcProvider *fakeOidcProvider

var server *httptest.Server

BeforeEach(func() {
var err error

fakeOidcProvider, err = newFakeOidcProvider(ProviderConfig{
IssuerURL: issuerURL,
AuthPath: "/auth",
TokenPath: "/token",
UserInfoPath: "/userInfo",
JWKSPath: "/jwks",
Algorithms: []string{"RS256", "HS256", "ES256", "PS256"},
})
Expect(err).To(BeNil())

server, err = fakeHttpServer(listenAddress, fakeOidcProvider.ServeHTTP)
Expect(err).To(BeNil())
})
Expand Down Expand Up @@ -81,7 +85,7 @@ var _ = Describe("authn-oidc", func() {
RefreshInterval: 5 * time.Minute,
BackoffInterval: 12 * time.Second,
BackoffMaxRetries: 5,
BackoffFrequency: 5 * time.Second,
BackoffFrequency: 1 * time.Second,
})
Expect(err).To(BeNil())

Expand Down Expand Up @@ -172,7 +176,7 @@ var _ = Describe("authn-oidc", func() {
RefreshInterval: 5 * time.Minute,
BackoffInterval: 12 * time.Second,
BackoffMaxRetries: 5,
BackoffFrequency: 5 * time.Second,
BackoffFrequency: 1 * time.Second,
})
Expect(err).To(BeNil())

Expand All @@ -181,36 +185,41 @@ var _ = Describe("authn-oidc", func() {
niceMd.Set("authorization", "Bearer "+idToken)
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err != nil).To(Equal(tt.wantErr), fmt.Sprintf("Wanted error: %t, got %v", tt.wantErr, err))
Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second)))
}
})
})

Context("Authenticate Key Ids", func() {
It("Case 1", func() {
tests := []struct {
name string
method jwt.SigningMethod
addKeyId bool
keyId string
wantErr bool
name string
method jwt.SigningMethod
addKeyId bool
keyId string
wantErr bool
wantTiming time.Duration
}{
{
"With no keyid using RS256 it should fail, multiple public RSA keys matching for RS256 and PS256",
jwt.SigningMethodRS256,
false, "",
true,
1 * time.Second,
},
{
"With right keyid using RS256 it should pass",
jwt.SigningMethodRS256,
true, fakeOidcProvider.keyIds[jwt.SigningMethodRS256],
false,
1 * time.Second,
},
{
"With wrong keyid using RS256 it should fail",
jwt.SigningMethodRS256,
true, "wrongkeyid",
true,
6 * time.Second,
},
}

Expand Down Expand Up @@ -240,7 +249,7 @@ var _ = Describe("authn-oidc", func() {
RefreshInterval: 5 * time.Minute,
BackoffInterval: 12 * time.Second,
BackoffMaxRetries: 5,
BackoffFrequency: 5 * time.Second,
BackoffFrequency: 1 * time.Second,
})
Expect(err).To(BeNil())

Expand All @@ -249,6 +258,7 @@ var _ = Describe("authn-oidc", func() {
niceMd.Set("authorization", "Bearer "+idToken)
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err != nil).To(Equal(tt.wantErr), fmt.Sprintf("Wanted error: %t, got %v", tt.wantErr, err))
Expect(time.Now()).To(BeTemporally("<=", now.Add(tt.wantTiming)))
}
})

Expand All @@ -261,7 +271,7 @@ var _ = Describe("authn-oidc", func() {
RefreshInterval: 5 * time.Minute,
BackoffInterval: 1 * time.Minute,
BackoffMaxRetries: 5,
BackoffFrequency: 12 * time.Second,
BackoffFrequency: 1 * time.Second,
})
Expect(err).To(BeNil())

Expand Down Expand Up @@ -305,6 +315,7 @@ var _ = Describe("authn-oidc", func() {
niceMd.Set("authorization", "Bearer "+idToken)
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err).Should(BeNil())
Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second)))
}
})

Expand All @@ -315,26 +326,29 @@ var _ = Describe("authn-oidc", func() {
Audience: audience,
Issuer: issuerURL,
RefreshInterval: 5 * time.Minute,
BackoffInterval: 12 * time.Second,
BackoffMaxRetries: 5,
BackoffFrequency: 5 * time.Second,
BackoffInterval: 6 * time.Second,
BackoffMaxRetries: 3,
BackoffFrequency: 1 * time.Second,
})
Expect(err).To(BeNil())

tests := []struct {
name string
method jwt.SigningMethod
newKeyId string
name string
method jwt.SigningMethod
newKeyId string
succeedAfterRetry bool
}{
{
"Invalid KID, should retry and fail",
"Invalid KID, should retry and fail, and then succeed",
jwt.SigningMethodRS256,
"invalidkey1",
true,
},
{
"Invalid KID, should retry and fail",
jwt.SigningMethodRS256,
"invalidkey2",
false,
},
}

Expand All @@ -357,56 +371,48 @@ var _ = Describe("authn-oidc", func() {
// authenticate with retries
niceMd := make(metautils.NiceMD)
niceMd.Set("authorization", "Bearer "+idToken)
now = time.Now()
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err).ShouldNot(BeNil())
}

// Wait for the backoff interval to expire
time.Sleep(13 * time.Second)
Expect(time.Now()).To(BeTemporally("<=", now.Add(4*time.Second)))

// Test with a valid KID after backoff
validKeyID := "validkey"
fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, validKeyID)
// authenticate after retries should fail immediately
now = time.Now()
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err).ShouldNot(BeNil())
Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second)))

now := time.Now()
claims := jwt.RegisteredClaims{
Issuer: issuerURL,
Subject: "user",
Audience: []string{audience},
ExpiresAt: &jwt.NumericDate{Time: now.AddDate(1, 0, 0)},
IssuedAt: &jwt.NumericDate{Time: now},
}
if tt.succeedAfterRetry {
// authenticate with now valid key should succeed immediately after backoff interval elasped
fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, tt.newKeyId)

// create signed token from oidc provider with valid kid in header
unsignedToken := createUnsignedToken(claims, jwt.SigningMethodRS256)
unsignedToken.Header["kid"] = validKeyID
idToken, err := fakeOidcProvider.SignIDToken(unsignedToken)
Expect(err).To(BeNil())
time.Sleep(7 * time.Second)

// authenticate with valid kid
niceMd := make(metautils.NiceMD)
niceMd.Set("authorization", "Bearer "+idToken)
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err).Should(BeNil())
now = time.Now()
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err).Should(BeNil())
Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second)))
}
}
})

It("Case 4: Concurrent requests leading to global backoff lock for 1 minute", func() {
It("Case 4: Concurrent requests leading to global backoff lock for 6 seconds", func() {
// create authenticator
ctx := context.Background()
auth, err := NewOidcAuthn(ctx, config.Oidc{
Audience: audience,
Issuer: issuerURL,
RefreshInterval: 5 * time.Minute,
BackoffInterval: 12 * time.Second,
BackoffMaxRetries: 5,
BackoffFrequency: 5 * time.Second,
BackoffInterval: 6 * time.Second,
BackoffMaxRetries: 3,
BackoffFrequency: 1 * time.Second,
})
Expect(err).To(BeNil())

invalidKeyIDs := []string{"invalidkey1", "invalidkey2"}

var wg sync.WaitGroup
numRequests := len(invalidKeyIDs) * 5 // Send each invalid key multiple times to ensure retries are hit
numRequests := len(invalidKeyIDs) * 3 // Send each invalid key multiple times to ensure retries are hit

// Helper function to create a token with the specified key ID
createTokenWithKid := func(kid string) (string, error) {
Expand All @@ -427,7 +433,12 @@ var _ = Describe("authn-oidc", func() {
return idToken, nil
}

// Set valid token to ensure it's already in cache for later tests
validKeyID := "validkey"
fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, validKeyID)

// Step 1: Trigger backoff by hitting max retries with invalid keys concurrently
now := time.Now()
for i := 0; i < numRequests; i++ {
wg.Add(1)
go func(i int) {
Expand All @@ -436,113 +447,60 @@ var _ = Describe("authn-oidc", func() {
token, _ := createTokenWithKid(keyID)
niceMd := make(metautils.NiceMD)
niceMd.Set("authorization", "Bearer "+token)
now := time.Now()
err := auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String()))
Expect(time.Now()).To(BeTemporally("<=", now.Add(4*time.Second)))
}(i)
}

wg.Wait()
Expect(time.Now()).To(BeTemporally("<=", now.Add(4*time.Second)))

// Step 2: Verify that retries are immediately rejected during backoff period
for _, keyID := range invalidKeyIDs {
token, _ := createTokenWithKid(keyID)
md := metadata.Pairs("authorization", "Bearer "+token)
ctx := metadata.NewIncomingContext(ctx, md)
now := time.Now()
err := auth.Authenticate(ctx)
Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String()))
Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second)))
}

// Step 3: Wait for the backoff interval to expire
time.Sleep(13 * time.Second)

// Step 4: Test with a valid KID after backoff
validKeyID := "validkey"

fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, validKeyID)

// Step 3: A valid KID already in the JWKS cache should continue to authenticate succesfully immediately
validToken, _ := createTokenWithKid(validKeyID)
niceMd := make(metautils.NiceMD)
niceMd.Set("authorization", "Bearer "+validToken)

now = time.Now()
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err).Should(BeNil())
Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second)))

// Step 5: Ensure that invalid keys are still rejected after backoff period
// Step 4: Ensure that invalid keys return immediately during backoff interval, even after valid KID
for _, keyID := range invalidKeyIDs {
token, _ := createTokenWithKid(keyID)
niceMd := make(metautils.NiceMD)
niceMd.Set("authorization", "Bearer "+token)
err = auth.Authenticate(niceMd.ToIncoming(ctx))
md := metadata.Pairs("authorization", "Bearer "+token)
ctx := metadata.NewIncomingContext(ctx, md)
now = time.Now()
err := auth.Authenticate(ctx)
Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String()))
Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second)))
}
})

It("Case 5", func() {
// create authenticator
ctx := context.Background()
auth, err := NewOidcAuthn(ctx, config.Oidc{
Audience: audience,
Issuer: issuerURL,
RefreshInterval: 5 * time.Minute,
BackoffInterval: 12 * time.Second,
BackoffMaxRetries: 5,
BackoffFrequency: 5 * time.Second,
})
Expect(err).To(BeNil())

invalidKeyID := "invalidkey"

// Helper function to create a token with the specified key ID
createTokenWithKid := func(kid string) (string, error) {
now := time.Now()
claims := jwt.RegisteredClaims{
Issuer: auth.IssuerURL,
Subject: "user",
Audience: []string{auth.Audience},
ExpiresAt: &jwt.NumericDate{Time: now.AddDate(1, 0, 0)},
IssuedAt: &jwt.NumericDate{Time: now},
}

// create signed token from oidc server with overridden claims
unsignedToken := createUnsignedToken(claims, jwt.SigningMethodRS256)
unsignedToken.Header["kid"] = kid
idToken, err := fakeOidcProvider.SignIDToken(unsignedToken)
Expect(err).To(BeNil())
return idToken, nil
}
// Step 5: Ensure that invalid keys are still rejected after backoff period
time.Sleep(7 * time.Second)

// Step 1: Trigger max retries by hitting invalid key multiple times
for i := 0; i <= auth.backoffMaxRetries; i++ {
token, _ := createTokenWithKid(invalidKeyID)
now = time.Now()
for _, keyID := range invalidKeyIDs {
token, _ := createTokenWithKid(keyID)
niceMd := make(metautils.NiceMD)
niceMd.Set("authorization", "Bearer "+token)
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String()))
}

// Step 2: Try to fetch once after max retries reached
token, _ := createTokenWithKid(invalidKeyID)
md := metadata.Pairs("authorization", "Bearer "+token)
ctx = metadata.NewIncomingContext(ctx, md)
err = auth.Authenticate(ctx)
Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String()))

validKeyID := "validkey"

fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, validKeyID)

// Step 3: Try with a valid key to ensure it resets the state
validToken, _ := createTokenWithKid(validKeyID)
niceMd := make(metautils.NiceMD)
niceMd.Set("authorization", "Bearer "+validToken)
err = auth.Authenticate(niceMd.ToIncoming(ctx))
Expect(err).Should(BeNil())

// Step 4: Ensure invalid keys are still rejected after state reset
token, _ = createTokenWithKid(invalidKeyID)
md = metadata.Pairs("authorization", "Bearer "+token)
ctx = metadata.NewIncomingContext(ctx, md)
err = auth.Authenticate(ctx)
Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String()))
Expect(time.Now()).To(BeTemporally("<=", now.Add(4*time.Second)))
})
})

Expand Down

0 comments on commit af360a2

Please sign in to comment.