diff --git a/internal/broker/broker.go b/internal/broker/broker.go index bf989786..3dfd92c9 100644 --- a/internal/broker/broker.go +++ b/internal/broker/broker.go @@ -619,14 +619,12 @@ func (b *Broker) handleIsAuthenticated(ctx context.Context, session *session, au // Refresh the token if we're online even if the token has not expired if !session.isOffline { authInfo, err = b.refreshToken(ctx, session.oauth2Config, authInfo) - var retrieveErr *oauth2.RetrieveError if errors.As(err, &retrieveErr) && b.provider.IsTokenExpiredError(*retrieveErr) { // The refresh token is expired, so the user needs to authenticate via OIDC again. session.nextAuthModes = []string{authmodes.Device, authmodes.DeviceQr} return AuthNext, nil } - if err != nil { log.Error(context.Background(), err.Error()) return AuthDenied, errorMessage{Message: "could not refresh token"} diff --git a/internal/broker/broker_test.go b/internal/broker/broker_test.go index 81c44fa6..68e4788f 100644 --- a/internal/broker/broker_test.go +++ b/internal/broker/broker_test.go @@ -406,6 +406,7 @@ func TestIsAuthenticated(t *testing.T) { dontWaitForFirstCall bool readOnlyDataDir bool wantGroups []info.Group + wantNextAuthModes []string }{ "Successfully_authenticate_user_with_device_auth_and_newpassword": {firstSecret: "-", wantSecondCall: true}, "Successfully_authenticate_user_with_password": {firstMode: authmodes.Password, token: &tokenOptions{}}, @@ -457,6 +458,13 @@ func TestIsAuthenticated(t *testing.T) { token: &tokenOptions{}, useOldNameForSecretField: true, }, + "Authenticating_with_qrcode_after_refresh_token_is_expired": { + firstMode: authmodes.Password, + token: &tokenOptions{refreshTokenExpired: true}, + wantNextAuthModes: []string{authmodes.Device, authmodes.DeviceQr}, + wantSecondCall: true, + secondMode: authmodes.DeviceQr, + }, "Error_when_authentication_data_is_invalid": {invalidAuthData: true}, "Error_when_secret_can_not_be_decrypted": {firstMode: authmodes.Password, badFirstKey: true}, @@ -634,6 +642,11 @@ func TestIsAuthenticated(t *testing.T) { err = os.WriteFile(filepath.Join(outDir, "first_call"), out, 0600) require.NoError(t, err, "Failed to write first response") + if tc.wantNextAuthModes != nil { + nextAuthModes := b.GetNextAuthModes(sessionID) + require.ElementsMatch(t, tc.wantNextAuthModes, nextAuthModes, "Next auth modes should match") + } + if tc.wantGroups != nil { type userInfoMsgType struct { UserInfo info.User `json:"userinfo"` diff --git a/internal/broker/export_test.go b/internal/broker/export_test.go index 9336df33..c0be5df8 100644 --- a/internal/broker/export_test.go +++ b/internal/broker/export_test.go @@ -113,6 +113,18 @@ func (b *Broker) DataDir() string { return b.cfg.DataDir } +// GetNextAuthModes returns the next auth mode of the specified session. +func (b *Broker) GetNextAuthModes(sessionID string) []string { + b.currentSessionsMu.Lock() + defer b.currentSessionsMu.Unlock() + + session, ok := b.currentSessions[sessionID] + if !ok { + return nil + } + return session.nextAuthModes +} + // SetNextAuthModes sets the next auth mode of the specified session. func (b *Broker) SetNextAuthModes(sessionID string, authModes []string) { b.currentSessionsMu.Lock() diff --git a/internal/broker/helper_test.go b/internal/broker/helper_test.go index c17187fe..7f6c4c28 100644 --- a/internal/broker/helper_test.go +++ b/internal/broker/helper_test.go @@ -177,12 +177,13 @@ type tokenOptions struct { issuer string groups []info.Group - expired bool - noRefreshToken bool - noIDToken bool - invalid bool - invalidClaims bool - noUserInfo bool + expired bool + noRefreshToken bool + refreshTokenExpired bool + noIDToken bool + invalid bool + invalidClaims bool + noUserInfo bool } func generateCachedInfo(t *testing.T, options tokenOptions) *token.AuthCachedInfo { @@ -226,6 +227,9 @@ func generateCachedInfo(t *testing.T, options tokenOptions) *token.AuthCachedInf if options.noRefreshToken { tok.Token.RefreshToken = "" } + if options.refreshTokenExpired { + tok.Token.RefreshToken = testutils.ExpiredRefreshToken + } if !options.noUserInfo { tok.UserInfo = info.User{ diff --git a/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/data/provider_url/test-user@email.com/password b/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/data/provider_url/test-user@email.com/password new file mode 100644 index 00000000..11994724 --- /dev/null +++ b/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/data/provider_url/test-user@email.com/password @@ -0,0 +1 @@ +Definitely a hashed password \ No newline at end of file diff --git a/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/data/provider_url/test-user@email.com/token.json b/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/data/provider_url/test-user@email.com/token.json new file mode 100644 index 00000000..80ab7838 --- /dev/null +++ b/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/data/provider_url/test-user@email.com/token.json @@ -0,0 +1 @@ +Definitely an encrypted token \ No newline at end of file diff --git a/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/first_call b/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/first_call new file mode 100644 index 00000000..d0887a13 --- /dev/null +++ b/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/first_call @@ -0,0 +1,3 @@ +access: next +data: '{}' +err: diff --git a/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/second_call b/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/second_call new file mode 100644 index 00000000..d0887a13 --- /dev/null +++ b/internal/broker/testdata/golden/TestIsAuthenticated/Next_auth_mode_is_device_qr_if_refresh_token_is_expired/second_call @@ -0,0 +1,3 @@ +access: next +data: '{}' +err: diff --git a/internal/providers/noprovider/noprovider.go b/internal/providers/noprovider/noprovider.go index cfe839b5..45f2d160 100644 --- a/internal/providers/noprovider/noprovider.go +++ b/internal/providers/noprovider/noprovider.go @@ -4,6 +4,7 @@ package noprovider import ( "context" "fmt" + "strings" "github.com/coreos/go-oidc/v3/oidc" "github.com/ubuntu/authd-oidc-brokers/internal/providers/info" @@ -103,6 +104,7 @@ func (p NoProvider) getGroups(_ *oauth2.Token) ([]info.Group, error) { // IsTokenExpiredError returns true if the reason for the error is that the refresh token is expired. func (p NoProvider) IsTokenExpiredError(err oauth2.RetrieveError) bool { - // There is no generic error for this, so we return false. - return false + // TODO: This is an msentraid specific error code and description. + // Change it to the ones from Google once we know them. + return err.ErrorCode == "invalid_grant" && strings.HasPrefix(err.ErrorDescription, "AADSTS50173:") } diff --git a/internal/testutils/provider.go b/internal/testutils/provider.go index d383e506..c6603070 100644 --- a/internal/testutils/provider.go +++ b/internal/testutils/provider.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/http/httputil" "slices" "strings" "sync" @@ -23,9 +24,15 @@ import ( "github.com/ubuntu/authd-oidc-brokers/internal/consts" "github.com/ubuntu/authd-oidc-brokers/internal/providers/info" "github.com/ubuntu/authd-oidc-brokers/internal/providers/noprovider" + "github.com/ubuntu/authd/log" "golang.org/x/oauth2" ) +const ( + // ExpiredRefreshToken is used to test the expired refresh token error. + ExpiredRefreshToken = "expired-refresh-token" +) + // MockKey is the RSA key used to sign the JWTs for the mock provider. var MockKey *rsa.PrivateKey @@ -197,6 +204,22 @@ func TokenHandler(serverURL string, opts *TokenHandlerOptions) EndpointHandler { } return func(w http.ResponseWriter, r *http.Request) { + s, err := httputil.DumpRequest(r, true) + if err != nil { + log.Errorf(context.Background(), "could not dump request: %v", err) + } + log.Debugf(context.Background(), "/token endpoint request:\n%s", s) + + // Handle expired refresh token + refreshToken := r.FormValue("refresh_token") + if refreshToken == ExpiredRefreshToken { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + // This is an msentraid specific error code and description. + _, _ = w.Write([]byte(`{"error": "invalid_grant", "error_description": "AADSTS50173: The refresh token has expired."}`)) + return + } + // Mimics user going through auth process time.Sleep(2 * time.Second)