From af360a20f0e2f14a42fff8f973295c0dfd5b380d Mon Sep 17 00:00:00 2001 From: Charlie Retzler Date: Tue, 25 Jun 2024 10:17:48 -0400 Subject: [PATCH] feat(authn): improve retry mechanism for fetching JWKS --- internal/authn/oidc/authn_test.go | 214 ++++++++++++------------------ 1 file changed, 86 insertions(+), 128 deletions(-) diff --git a/internal/authn/oidc/authn_test.go b/internal/authn/oidc/authn_test.go index 6e54fdf58..f83f6f665 100644 --- a/internal/authn/oidc/authn_test.go +++ b/internal/authn/oidc/authn_test.go @@ -22,19 +22,23 @@ var _ = Describe("authn-oidc", func() { audience := "aud" listenAddress := "localhost:9999" issuerURL := "http://" + listenAddress - fakeOidcProvider, _ := newFakeOidcProvider(ProviderConfig{ - IssuerURL: issuerURL, - AuthPath: "/auth", - TokenPath: "/token", - UserInfoPath: "/userInfo", - JWKSPath: "/jwks", - Algorithms: []string{"RS256", "HS256", "ES256", "PS256"}, - }) + var fakeOidcProvider *fakeOidcProvider var server *httptest.Server BeforeEach(func() { var err error + + fakeOidcProvider, err = newFakeOidcProvider(ProviderConfig{ + IssuerURL: issuerURL, + AuthPath: "/auth", + TokenPath: "/token", + UserInfoPath: "/userInfo", + JWKSPath: "/jwks", + Algorithms: []string{"RS256", "HS256", "ES256", "PS256"}, + }) + Expect(err).To(BeNil()) + server, err = fakeHttpServer(listenAddress, fakeOidcProvider.ServeHTTP) Expect(err).To(BeNil()) }) @@ -81,7 +85,7 @@ var _ = Describe("authn-oidc", func() { RefreshInterval: 5 * time.Minute, BackoffInterval: 12 * time.Second, BackoffMaxRetries: 5, - BackoffFrequency: 5 * time.Second, + BackoffFrequency: 1 * time.Second, }) Expect(err).To(BeNil()) @@ -172,7 +176,7 @@ var _ = Describe("authn-oidc", func() { RefreshInterval: 5 * time.Minute, BackoffInterval: 12 * time.Second, BackoffMaxRetries: 5, - BackoffFrequency: 5 * time.Second, + BackoffFrequency: 1 * time.Second, }) Expect(err).To(BeNil()) @@ -181,6 +185,7 @@ var _ = Describe("authn-oidc", func() { niceMd.Set("authorization", "Bearer "+idToken) err = auth.Authenticate(niceMd.ToIncoming(ctx)) Expect(err != nil).To(Equal(tt.wantErr), fmt.Sprintf("Wanted error: %t, got %v", tt.wantErr, err)) + Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second))) } }) }) @@ -188,29 +193,33 @@ var _ = Describe("authn-oidc", func() { Context("Authenticate Key Ids", func() { It("Case 1", func() { tests := []struct { - name string - method jwt.SigningMethod - addKeyId bool - keyId string - wantErr bool + name string + method jwt.SigningMethod + addKeyId bool + keyId string + wantErr bool + wantTiming time.Duration }{ { "With no keyid using RS256 it should fail, multiple public RSA keys matching for RS256 and PS256", jwt.SigningMethodRS256, false, "", true, + 1 * time.Second, }, { "With right keyid using RS256 it should pass", jwt.SigningMethodRS256, true, fakeOidcProvider.keyIds[jwt.SigningMethodRS256], false, + 1 * time.Second, }, { "With wrong keyid using RS256 it should fail", jwt.SigningMethodRS256, true, "wrongkeyid", true, + 6 * time.Second, }, } @@ -240,7 +249,7 @@ var _ = Describe("authn-oidc", func() { RefreshInterval: 5 * time.Minute, BackoffInterval: 12 * time.Second, BackoffMaxRetries: 5, - BackoffFrequency: 5 * time.Second, + BackoffFrequency: 1 * time.Second, }) Expect(err).To(BeNil()) @@ -249,6 +258,7 @@ var _ = Describe("authn-oidc", func() { niceMd.Set("authorization", "Bearer "+idToken) err = auth.Authenticate(niceMd.ToIncoming(ctx)) Expect(err != nil).To(Equal(tt.wantErr), fmt.Sprintf("Wanted error: %t, got %v", tt.wantErr, err)) + Expect(time.Now()).To(BeTemporally("<=", now.Add(tt.wantTiming))) } }) @@ -261,7 +271,7 @@ var _ = Describe("authn-oidc", func() { RefreshInterval: 5 * time.Minute, BackoffInterval: 1 * time.Minute, BackoffMaxRetries: 5, - BackoffFrequency: 12 * time.Second, + BackoffFrequency: 1 * time.Second, }) Expect(err).To(BeNil()) @@ -305,6 +315,7 @@ var _ = Describe("authn-oidc", func() { niceMd.Set("authorization", "Bearer "+idToken) err = auth.Authenticate(niceMd.ToIncoming(ctx)) Expect(err).Should(BeNil()) + Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second))) } }) @@ -315,26 +326,29 @@ var _ = Describe("authn-oidc", func() { Audience: audience, Issuer: issuerURL, RefreshInterval: 5 * time.Minute, - BackoffInterval: 12 * time.Second, - BackoffMaxRetries: 5, - BackoffFrequency: 5 * time.Second, + BackoffInterval: 6 * time.Second, + BackoffMaxRetries: 3, + BackoffFrequency: 1 * time.Second, }) Expect(err).To(BeNil()) tests := []struct { - name string - method jwt.SigningMethod - newKeyId string + name string + method jwt.SigningMethod + newKeyId string + succeedAfterRetry bool }{ { - "Invalid KID, should retry and fail", + "Invalid KID, should retry and fail, and then succeed", jwt.SigningMethodRS256, "invalidkey1", + true, }, { "Invalid KID, should retry and fail", jwt.SigningMethodRS256, "invalidkey2", + false, }, } @@ -357,56 +371,48 @@ var _ = Describe("authn-oidc", func() { // authenticate with retries niceMd := make(metautils.NiceMD) niceMd.Set("authorization", "Bearer "+idToken) + now = time.Now() err = auth.Authenticate(niceMd.ToIncoming(ctx)) Expect(err).ShouldNot(BeNil()) - } - - // Wait for the backoff interval to expire - time.Sleep(13 * time.Second) + Expect(time.Now()).To(BeTemporally("<=", now.Add(4*time.Second))) - // Test with a valid KID after backoff - validKeyID := "validkey" - fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, validKeyID) + // authenticate after retries should fail immediately + now = time.Now() + err = auth.Authenticate(niceMd.ToIncoming(ctx)) + Expect(err).ShouldNot(BeNil()) + Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second))) - now := time.Now() - claims := jwt.RegisteredClaims{ - Issuer: issuerURL, - Subject: "user", - Audience: []string{audience}, - ExpiresAt: &jwt.NumericDate{Time: now.AddDate(1, 0, 0)}, - IssuedAt: &jwt.NumericDate{Time: now}, - } + if tt.succeedAfterRetry { + // authenticate with now valid key should succeed immediately after backoff interval elasped + fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, tt.newKeyId) - // create signed token from oidc provider with valid kid in header - unsignedToken := createUnsignedToken(claims, jwt.SigningMethodRS256) - unsignedToken.Header["kid"] = validKeyID - idToken, err := fakeOidcProvider.SignIDToken(unsignedToken) - Expect(err).To(BeNil()) + time.Sleep(7 * time.Second) - // authenticate with valid kid - niceMd := make(metautils.NiceMD) - niceMd.Set("authorization", "Bearer "+idToken) - err = auth.Authenticate(niceMd.ToIncoming(ctx)) - Expect(err).Should(BeNil()) + now = time.Now() + err = auth.Authenticate(niceMd.ToIncoming(ctx)) + Expect(err).Should(BeNil()) + Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second))) + } + } }) - It("Case 4: Concurrent requests leading to global backoff lock for 1 minute", func() { + It("Case 4: Concurrent requests leading to global backoff lock for 6 seconds", func() { // create authenticator ctx := context.Background() auth, err := NewOidcAuthn(ctx, config.Oidc{ Audience: audience, Issuer: issuerURL, RefreshInterval: 5 * time.Minute, - BackoffInterval: 12 * time.Second, - BackoffMaxRetries: 5, - BackoffFrequency: 5 * time.Second, + BackoffInterval: 6 * time.Second, + BackoffMaxRetries: 3, + BackoffFrequency: 1 * time.Second, }) Expect(err).To(BeNil()) invalidKeyIDs := []string{"invalidkey1", "invalidkey2"} var wg sync.WaitGroup - numRequests := len(invalidKeyIDs) * 5 // Send each invalid key multiple times to ensure retries are hit + numRequests := len(invalidKeyIDs) * 3 // Send each invalid key multiple times to ensure retries are hit // Helper function to create a token with the specified key ID createTokenWithKid := func(kid string) (string, error) { @@ -427,7 +433,12 @@ var _ = Describe("authn-oidc", func() { return idToken, nil } + // Set valid token to ensure it's already in cache for later tests + validKeyID := "validkey" + fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, validKeyID) + // Step 1: Trigger backoff by hitting max retries with invalid keys concurrently + now := time.Now() for i := 0; i < numRequests; i++ { wg.Add(1) go func(i int) { @@ -436,113 +447,60 @@ var _ = Describe("authn-oidc", func() { token, _ := createTokenWithKid(keyID) niceMd := make(metautils.NiceMD) niceMd.Set("authorization", "Bearer "+token) + now := time.Now() err := auth.Authenticate(niceMd.ToIncoming(ctx)) Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())) + Expect(time.Now()).To(BeTemporally("<=", now.Add(4*time.Second))) }(i) } wg.Wait() + Expect(time.Now()).To(BeTemporally("<=", now.Add(4*time.Second))) // Step 2: Verify that retries are immediately rejected during backoff period for _, keyID := range invalidKeyIDs { token, _ := createTokenWithKid(keyID) md := metadata.Pairs("authorization", "Bearer "+token) ctx := metadata.NewIncomingContext(ctx, md) + now := time.Now() err := auth.Authenticate(ctx) Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())) + Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second))) } - // Step 3: Wait for the backoff interval to expire - time.Sleep(13 * time.Second) - - // Step 4: Test with a valid KID after backoff - validKeyID := "validkey" - - fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, validKeyID) - + // Step 3: A valid KID already in the JWKS cache should continue to authenticate succesfully immediately validToken, _ := createTokenWithKid(validKeyID) niceMd := make(metautils.NiceMD) niceMd.Set("authorization", "Bearer "+validToken) + + now = time.Now() err = auth.Authenticate(niceMd.ToIncoming(ctx)) Expect(err).Should(BeNil()) + Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second))) - // Step 5: Ensure that invalid keys are still rejected after backoff period + // Step 4: Ensure that invalid keys return immediately during backoff interval, even after valid KID for _, keyID := range invalidKeyIDs { token, _ := createTokenWithKid(keyID) - niceMd := make(metautils.NiceMD) - niceMd.Set("authorization", "Bearer "+token) - err = auth.Authenticate(niceMd.ToIncoming(ctx)) + md := metadata.Pairs("authorization", "Bearer "+token) + ctx := metadata.NewIncomingContext(ctx, md) + now = time.Now() + err := auth.Authenticate(ctx) Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())) + Expect(time.Now()).To(BeTemporally("<=", now.Add(1*time.Second))) } - }) - - It("Case 5", func() { - // create authenticator - ctx := context.Background() - auth, err := NewOidcAuthn(ctx, config.Oidc{ - Audience: audience, - Issuer: issuerURL, - RefreshInterval: 5 * time.Minute, - BackoffInterval: 12 * time.Second, - BackoffMaxRetries: 5, - BackoffFrequency: 5 * time.Second, - }) - Expect(err).To(BeNil()) - - invalidKeyID := "invalidkey" - - // Helper function to create a token with the specified key ID - createTokenWithKid := func(kid string) (string, error) { - now := time.Now() - claims := jwt.RegisteredClaims{ - Issuer: auth.IssuerURL, - Subject: "user", - Audience: []string{auth.Audience}, - ExpiresAt: &jwt.NumericDate{Time: now.AddDate(1, 0, 0)}, - IssuedAt: &jwt.NumericDate{Time: now}, - } - // create signed token from oidc server with overridden claims - unsignedToken := createUnsignedToken(claims, jwt.SigningMethodRS256) - unsignedToken.Header["kid"] = kid - idToken, err := fakeOidcProvider.SignIDToken(unsignedToken) - Expect(err).To(BeNil()) - return idToken, nil - } + // Step 5: Ensure that invalid keys are still rejected after backoff period + time.Sleep(7 * time.Second) - // Step 1: Trigger max retries by hitting invalid key multiple times - for i := 0; i <= auth.backoffMaxRetries; i++ { - token, _ := createTokenWithKid(invalidKeyID) + now = time.Now() + for _, keyID := range invalidKeyIDs { + token, _ := createTokenWithKid(keyID) niceMd := make(metautils.NiceMD) niceMd.Set("authorization", "Bearer "+token) err = auth.Authenticate(niceMd.ToIncoming(ctx)) Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())) } - - // Step 2: Try to fetch once after max retries reached - token, _ := createTokenWithKid(invalidKeyID) - md := metadata.Pairs("authorization", "Bearer "+token) - ctx = metadata.NewIncomingContext(ctx, md) - err = auth.Authenticate(ctx) - Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())) - - validKeyID := "validkey" - - fakeOidcProvider.UpdateKeyID(jwt.SigningMethodRS256, validKeyID) - - // Step 3: Try with a valid key to ensure it resets the state - validToken, _ := createTokenWithKid(validKeyID) - niceMd := make(metautils.NiceMD) - niceMd.Set("authorization", "Bearer "+validToken) - err = auth.Authenticate(niceMd.ToIncoming(ctx)) - Expect(err).Should(BeNil()) - - // Step 4: Ensure invalid keys are still rejected after state reset - token, _ = createTokenWithKid(invalidKeyID) - md = metadata.Pairs("authorization", "Bearer "+token) - ctx = metadata.NewIncomingContext(ctx, md) - err = auth.Authenticate(ctx) - Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())) + Expect(time.Now()).To(BeTemporally("<=", now.Add(4*time.Second))) }) })