Skip to content

Commit

Permalink
Merge pull request #2692 from flant/refresh-once
Browse files Browse the repository at this point in the history
fix: refresh token only once for all concurrent requests
  • Loading branch information
sagikazarmark authored Oct 3, 2022
2 parents ffeb4d5 + 4b5f1d5 commit e4bceef
Show file tree
Hide file tree
Showing 4 changed files with 357 additions and 85 deletions.
205 changes: 124 additions & 81 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ type refreshError struct {
desc string
}

func (r *refreshError) Error() string {
return fmt.Sprintf("refresh token error: status %d, %q %s", r.code, r.msg, r.desc)
}

func newInternalServerError() *refreshError {
return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
}
Expand Down Expand Up @@ -60,10 +64,23 @@ func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.Refr
return token, nil
}

type refreshContext struct {
storageToken *storage.RefreshToken
requestToken *internal.RefreshToken

connector Connector
connectorData []byte

scopes []string
}

// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) {
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*refreshContext, *refreshError) {
refreshCtx := refreshContext{requestToken: token}

invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.")

// Get RefreshToken
refresh, err := s.storage.GetRefresh(token.RefreshId)
if err != nil {
if err != storage.ErrNotFound {
Expand Down Expand Up @@ -103,7 +120,31 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref
return nil, expiredErr
}

return &refresh, nil
refreshCtx.storageToken = &refresh

// Get Connector
refreshCtx.connector, err = s.getConnector(refresh.ConnectorID)
if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
return nil, newInternalServerError()
}

// Get Connector Data
session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
switch {
case err != nil:
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return nil, newInternalServerError()
}
case len(refresh.ConnectorData) > 0:
// Use the old connector data if it exists, should be deleted once used
refreshCtx.connectorData = refresh.ConnectorData
default:
refreshCtx.connectorData = session.ConnectorData
}

return &refreshCtx, nil
}

func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) {
Expand Down Expand Up @@ -138,59 +179,23 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken
return requestedScopes, nil
}

func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) {
var connectorData []byte

session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
switch {
case err != nil:
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return connector.Identity{}, newInternalServerError()
}
case len(refresh.ConnectorData) > 0:
// Use the old connector data if it exists, should be deleted once used
connectorData = refresh.ConnectorData
default:
connectorData = session.ConnectorData
}

conn, err := s.getConnector(refresh.ConnectorID)
if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
return connector.Identity{}, newInternalServerError()
}

ident := connector.Identity{
UserID: refresh.Claims.UserID,
Username: refresh.Claims.Username,
PreferredUsername: refresh.Claims.PreferredUsername,
Email: refresh.Claims.Email,
EmailVerified: refresh.Claims.EmailVerified,
Groups: refresh.Claims.Groups,
ConnectorData: connectorData,
}

// user's token was previously updated by a connector and is allowed to reuse
// it is excessive to refresh identity in upstream
if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken {
return ident, nil
}

func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, ident connector.Identity) (connector.Identity, *refreshError) {
// Can the connector refresh the identity? If so, attempt to refresh the data
// in the connector.
//
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
// this interface can't perform refreshing.
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident)
if refreshConn, ok := rCtx.connector.Connector.(connector.RefreshConnector); ok {
s.logger.Debugf("connector data before refresh: %s", ident.ConnectorData)

newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident)
if err != nil {
s.logger.Errorf("failed to refresh identity: %v", err)
return connector.Identity{}, newInternalServerError()
return ident, newInternalServerError()
}
ident = newIdent
}

return newIdent, nil
}
return ident, nil
}

Expand All @@ -200,8 +205,14 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne
if old.Refresh[refresh.ClientID].ID != refresh.ID {
return old, errors.New("refresh token invalid")
}

old.Refresh[refresh.ClientID].LastUsed = lastUsed
old.ConnectorData = ident.ConnectorData
if len(ident.ConnectorData) > 0 {
old.ConnectorData = ident.ConnectorData
}

s.logger.Debugf("saved connector data: %s %s", ident.UserID, ident.ConnectorData)

return old, nil
}

Expand All @@ -217,33 +228,74 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne
}

// updateRefreshToken updates refresh token and offline session in the storage
func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) {
newToken := token
if s.refreshTokenPolicy.RotationEnabled() {
newToken = &internal.RefreshToken{
RefreshId: refresh.ID,
Token: storage.NewID(),
}
func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) {
var rerr *refreshError

newToken := &internal.RefreshToken{
Token: rCtx.requestToken.Token,
RefreshId: rCtx.requestToken.RefreshId,
}

lastUsed := s.now()

ident := connector.Identity{
UserID: rCtx.storageToken.Claims.UserID,
Username: rCtx.storageToken.Claims.Username,
PreferredUsername: rCtx.storageToken.Claims.PreferredUsername,
Email: rCtx.storageToken.Claims.Email,
EmailVerified: rCtx.storageToken.Claims.EmailVerified,
Groups: rCtx.storageToken.Claims.Groups,
ConnectorData: rCtx.connectorData,
}

refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
if s.refreshTokenPolicy.RotationEnabled() {
if old.Token != token.Token {
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token {
newToken.Token = old.Token
// Do not update last used time for offline session if token is allowed to be reused
lastUsed = old.LastUsed
return old, nil
}
rotationEnabled := s.refreshTokenPolicy.RotationEnabled()
reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed)

switch {
case !rotationEnabled && reusingAllowed:
// If rotation is disabled and the offline session was updated not so long ago - skip further actions.
return old, nil

case rotationEnabled && reusingAllowed:
if old.Token != rCtx.requestToken.Token && old.ObsoleteToken != rCtx.requestToken.Token {
return old, errors.New("refresh token claimed twice")
}

// Return previously generated token for all requests with an obsolete tokens
if old.ObsoleteToken == rCtx.requestToken.Token {
newToken.Token = old.Token
}

// Do not update last used time for offline session if token is allowed to be reused
lastUsed = old.LastUsed
ident.ConnectorData = nil
return old, nil

case rotationEnabled && !reusingAllowed:
if old.Token != rCtx.requestToken.Token {
return old, errors.New("refresh token claimed twice")
}

// Issue new refresh token
old.ObsoleteToken = old.Token
newToken.Token = storage.NewID()
}

old.Token = newToken.Token
old.LastUsed = lastUsed

// ConnectorData has been moved to OfflineSession
old.ConnectorData = []byte{}

// Call only once if there is a request which is not in the reuse interval.
// This is required to avoid multiple calls to the external IdP for concurrent requests.
// Dex will call the connector's Refresh method only once if request is not in reuse interval.
ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
if rerr != nil {
return old, rerr
}

// Update the claims of the refresh token.
//
// UserID intentionally ignored for now.
Expand All @@ -252,26 +304,23 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora
old.Claims.Email = ident.Email
old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups
old.LastUsed = lastUsed

// ConnectorData has been moved to OfflineSession
old.ConnectorData = []byte{}
return old, nil
}

// Update refresh token in the storage.
err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater)
err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater)
if err != nil {
s.logger.Errorf("failed to update refresh token: %v", err)
return nil, newInternalServerError()
return nil, ident, newInternalServerError()
}

rerr := s.updateOfflineSession(refresh, ident, lastUsed)
rerr = s.updateOfflineSession(rCtx.storageToken, ident, lastUsed)
if rerr != nil {
return nil, rerr
return nil, ident, rerr
}

return newToken, nil
return newToken, ident, nil
}

// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6
Expand All @@ -283,19 +332,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return
}

refresh, rerr := s.getRefreshTokenFromStorage(client.ID, token)
rCtx, rerr := s.getRefreshTokenFromStorage(client.ID, token)
if rerr != nil {
s.refreshTokenErrHelper(w, rerr)
return
}

scopes, rerr := s.getRefreshScopes(r, refresh)
rCtx.scopes, rerr = s.getRefreshScopes(r, rCtx.storageToken)
if rerr != nil {
s.refreshTokenErrHelper(w, rerr)
return
}

ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes)
newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx)
if rerr != nil {
s.refreshTokenErrHelper(w, rerr)
return
Expand All @@ -310,26 +359,20 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Groups: ident.Groups,
}

accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
accessToken, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.refreshTokenErrHelper(w, newInternalServerError())
return
}

idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID)
idToken, expiry, err := s.newIDToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
s.refreshTokenErrHelper(w, newInternalServerError())
return
}

newToken, rerr := s.updateRefreshToken(token, refresh, ident)
if rerr != nil {
s.refreshTokenErrHelper(w, rerr)
return
}

rawNewToken, err := internal.Marshal(newToken)
if err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err)
Expand Down
Loading

0 comments on commit e4bceef

Please sign in to comment.