diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ca35438..1d398cd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] - `session.CreateNewSession` now defaults to the value of the `st-auth-mode` header (if available) if the configured `config.GetTokenTransferMethod` returns `any`. +- Enable smooth switching between `useDynamicAccessTokenSigningKey` settings by allowing refresh calls to change the signing key type of a session. ## [0.17.5] - 2024-03-14 - Adds a type uint64 to the `accessTokenCookiesExpiryDurationMillis` local variable in `recipe/session/utils.go`. It also removes the redundant `uint64` type forcing needed because of the untyped variable. diff --git a/recipe/session/recipeImplementation.go b/recipe/session/recipeImplementation.go index ca8b9b0e..c6b2a464 100644 --- a/recipe/session/recipeImplementation.go +++ b/recipe/session/recipeImplementation.go @@ -294,7 +294,7 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ supertokens.LogDebugMessage("refreshSession: Started") - response, err := refreshSessionHelper(config, querier, refreshToken, antiCsrfToken, disableAntiCsrf, userContext) + response, err := refreshSessionHelper(config, querier, refreshToken, antiCsrfToken, disableAntiCsrf, config.UseDynamicAccessTokenSigningKey, userContext) if err != nil { return nil, err } diff --git a/recipe/session/sessionFunctions.go b/recipe/session/sessionFunctions.go index 43b0f858..fb7be614 100644 --- a/recipe/session/sessionFunctions.go +++ b/recipe/session/sessionFunctions.go @@ -224,10 +224,11 @@ func getSessionInformationHelper(querier supertokens.Querier, sessionHandle stri return nil, nil } -func refreshSessionHelper(config sessmodels.TypeNormalisedInput, querier supertokens.Querier, refreshToken string, antiCsrfToken *string, disableAntiCsrf bool, userContext supertokens.UserContext) (sessmodels.CreateOrRefreshAPIResponse, error) { +func refreshSessionHelper(config sessmodels.TypeNormalisedInput, querier supertokens.Querier, refreshToken string, antiCsrfToken *string, disableAntiCsrf bool, useDynamicAccessTokenSigningKey bool, userContext supertokens.UserContext) (sessmodels.CreateOrRefreshAPIResponse, error) { requestBody := map[string]interface{}{ - "refreshToken": refreshToken, - "enableAntiCsrf": !disableAntiCsrf && config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_TOKEN, + "refreshToken": refreshToken, + "enableAntiCsrf": !disableAntiCsrf && config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_TOKEN, + "useDynamicSigningKey": useDynamicAccessTokenSigningKey, } if antiCsrfToken != nil { requestBody["antiCsrfToken"] = *antiCsrfToken diff --git a/recipe/session/sessionHandlingFuncsWithoutReq_test.go b/recipe/session/sessionHandlingFuncsWithoutReq_test.go index e5085137..e27cfb67 100644 --- a/recipe/session/sessionHandlingFuncsWithoutReq_test.go +++ b/recipe/session/sessionHandlingFuncsWithoutReq_test.go @@ -2,6 +2,7 @@ package session import ( "errors" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -472,3 +473,87 @@ func TestRefreshShouldReturnErrorForNonTokens(t *testing.T) { assert.NotNil(t, err2) assert.True(t, errors.As(err2, &sessionError.UnauthorizedError{})) } + +func TestUseDynamicAccessTokenSigningKey(t *testing.T) { + useDynamicAccessTokenSigningKey := true + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + APIDomain: "api.supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(&sessmodels.TypeInput{ + UseDynamicAccessTokenSigningKey: &useDynamicAccessTokenSigningKey, + }), + }, + } + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + + checkAccessTokenSigningKeyType := func(t *testing.T, tokens sessmodels.SessionTokens, isDynamic bool) { + t.Helper() + + info, err := ParseJWTWithoutSignatureVerification(tokens.AccessToken) + assert.NoError(t, err) + + if isDynamic { + assert.True(t, strings.HasPrefix(*info.KID, "d-")) + } else { + assert.True(t, strings.HasPrefix(*info.KID, "s-")) + } + } + + err := supertokens.Init(configValue) + assert.NoError(t, err) + + res, err := CreateNewSessionWithoutRequestResponse("public", "test-user-id", map[string]interface{}{ + "tokenProp": true, + }, map[string]interface{}{ + "dbProp": true, + }, nil) + + assert.NoError(t, err) + + tokens := res.GetAllSessionTokensDangerously() + checkAccessTokenSigningKeyType(t, tokens, true) + + resetAll() + + // here we change to false + useDynamicAccessTokenSigningKey = false + err = supertokens.Init(configValue) + assert.NoError(t, err) + + t.Run("should throw when verifying", func(t *testing.T) { + _, err = GetSessionWithoutRequestResponse(tokens.AccessToken, tokens.AntiCsrfToken, nil) + assert.Equal(t, err.Error(), "The access token doesn't match the useDynamicAccessTokenSigningKey setting") + }) + + t.Run("should work after refresh", func(t *testing.T) { + disableAntiCsrf := true + refreshedSession, err := RefreshSessionWithoutRequestResponse(*tokens.RefreshToken, &disableAntiCsrf, tokens.AntiCsrfToken) + assert.NoError(t, err) + + tokensAfterRefresh := refreshedSession.GetAllSessionTokensDangerously() + assert.True(t, tokensAfterRefresh.AccessAndFrontendTokenUpdated) + checkAccessTokenSigningKeyType(t, tokensAfterRefresh, false) + + verifiedSession, err := GetSessionWithoutRequestResponse(tokensAfterRefresh.AccessToken, tokensAfterRefresh.AntiCsrfToken, nil) + assert.NoError(t, err) + + tokensAfterVerify := verifiedSession.GetAllSessionTokensDangerously() + assert.True(t, tokensAfterVerify.AccessAndFrontendTokenUpdated) + checkAccessTokenSigningKeyType(t, tokensAfterVerify, false) + + verifiedSession2, err := GetSessionWithoutRequestResponse(tokensAfterVerify.AccessToken, tokensAfterVerify.AntiCsrfToken, nil) + assert.NoError(t, err) + + tokensAfterVerify2 := verifiedSession2.GetAllSessionTokensDangerously() + assert.False(t, tokensAfterVerify2.AccessAndFrontendTokenUpdated) + }) +}