Skip to content

Commit

Permalink
feat: add manual linking APIs (supabase#1317)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?
* Adds a new endpoint `GET /user/identities/authorize` which is an
endpoint to initiate the manual linking process and can only be invoked
if the user is authenticated
* `GET /user/identities/authorize` functions similarly to `GET
/authorize` where the user needs to login to the new oauth identity in
order to link the identity
* Example
```curl
// sign in with one of the supported auth methods to get the user's access token JWT first

// start the identity linking process
$ curl -X GET "http://localhost:9999/user/identities/authorize?provider=google" -H "Authorization: Bearer ACCESS_TOKEN_JWT"

{"url":"https://oauth_provider_url.com/path/to/sign-in"}

// visit the url returned and login to the oauth provider 
// request will be redirected to the /callback endpoint

// if the identity is successfully linked, the request will be redirected to `http://localhost:3000/#access_token=xxx&....`

// if the identity already exists, the request will be redirect to:
// http://localhost:3000/?error=invalid_request&error_code=400&error_description=Identity+is+already+linked+to+another+user#error=invalid_request&error_code=400&error_description=Identity+is+already+linked+to+another+user
```

## Details
* The callback endpoint used will be the same callback as the oauth
sign-in flow so that the developer doesn't have to add any additional
callback URLs to the oauth provider in order to enable manual linking
* A special field `LinkingTargetId` is introduced in the oauth state to
store the linking target user ID. This ID will be used in the callback
to determine the target user to link the candidate identity used
* If the identity is already linked to the current user or another user,
an error will be returned
* If the identity doesn't exist, then it will be successfully linked to
the existing user and a new access & refresh token will be issued.

---------

Co-authored-by: Stojan Dimitrovski <[email protected]>
  • Loading branch information
kangmingtay and hf authored Nov 29, 2023
1 parent 9b7d7f6 commit e39ce7d
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 50 deletions.
6 changes: 5 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
r.Get("/", api.UserGet)
r.With(sharedLimiter).Put("/", api.UserUpdate)

r.Delete("/identities/{identity_id}", api.DeleteIdentity)
r.Route("/identities", func(r *router) {
r.Use(api.requireManualLinkingEnabled)
r.Get("/authorize", api.LinkIdentity)
r.Delete("/{identity_id}", api.DeleteIdentity)
})
})

r.With(api.requireAuthentication).Route("/factors", func(r *router) {
Expand Down
18 changes: 18 additions & 0 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
signatureKey = contextKey("signature")
externalProviderTypeKey = contextKey("external_provider_type")
userKey = contextKey("user")
targetUserKey = contextKey("target_user")
factorKey = contextKey("factor")
sessionKey = contextKey("session")
externalReferrerKey = contextKey("external_referrer")
Expand Down Expand Up @@ -76,6 +77,11 @@ func withUser(ctx context.Context, u *models.User) context.Context {
return context.WithValue(ctx, userKey, u)
}

// withTargetUser adds the target user for linking to the context.
func withTargetUser(ctx context.Context, u *models.User) context.Context {
return context.WithValue(ctx, targetUserKey, u)
}

// with Factor adds the factor id to the context.
func withFactor(ctx context.Context, f *models.Factor) context.Context {
return context.WithValue(ctx, factorKey, f)
Expand All @@ -93,6 +99,18 @@ func getUser(ctx context.Context) *models.User {
return obj.(*models.User)
}

// getTargetUser reads the user from the context.
func getTargetUser(ctx context.Context) *models.User {
if ctx == nil {
return nil
}
obj := ctx.Value(targetUserKey)
if obj == nil {
return nil
}
return obj.(*models.User)
}

// getFactor reads the factor id from the context
func getFactor(ctx context.Context) *models.Factor {
obj := ctx.Value(factorKey)
Expand Down
130 changes: 82 additions & 48 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,25 @@ import (
// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
type ExternalProviderClaims struct {
AuthMicroserviceClaims
Provider string `json:"provider"`
InviteToken string `json:"invite_token,omitempty"`
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
Provider string `json:"provider"`
InviteToken string `json:"invite_token,omitempty"`
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
LinkingTargetID string `json:"linking_target_id,omitempty"`
}

// ExternalProviderRedirect redirects the request to the corresponding oauth provider
// ExternalProviderRedirect redirects the request to the oauth provider
func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error {
rurl, err := a.GetExternalProviderRedirectURL(w, r, nil)
if err != nil {
return err
}
http.Redirect(w, r, rurl, http.StatusFound)
return nil
}

// GetExternalProviderRedirectURL returns the URL to start the oauth flow with the corresponding oauth provider
func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request, linkingTargetUser *models.User) (string, error) {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config
Expand All @@ -45,45 +56,45 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e

p, err := a.Provider(ctx, providerType, scopes)
if err != nil {
return badRequestError("Unsupported provider: %+v", err).WithInternalError(err)
return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err)
}

inviteToken := query.Get("invite_token")
if inviteToken != "" {
_, userErr := models.FindUserByConfirmationToken(db, inviteToken)
if userErr != nil {
if models.IsNotFoundError(userErr) {
return notFoundError(userErr.Error())
return "", notFoundError(userErr.Error())
}
return internalServerError("Database error finding user").WithInternalError(userErr)
return "", internalServerError("Database error finding user").WithInternalError(userErr)
}
}

redirectURL := utilities.GetReferrer(r, config)
log := observability.GetLogEntry(r)
log.WithField("provider", providerType).Info("Redirecting to external provider")
if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
return err
return "", err
}
flowType := getFlowFromChallenge(codeChallenge)

flowStateID := ""
if flowType == models.PKCEFlow {
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return err
return "", err
}
flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth)
if err != nil {
return err
return "", err
}
if err := a.db.Create(flowState); err != nil {
return err
return "", err
}
flowStateID = flowState.ID.String()
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, ExternalProviderClaims{
claims := ExternalProviderClaims{
AuthMicroserviceClaims: AuthMicroserviceClaims{
StandardClaims: jwt.StandardClaims{
ExpiresAt: time.Now().Add(5 * time.Minute).Unix(),
Expand All @@ -95,10 +106,17 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e
InviteToken: inviteToken,
Referrer: redirectURL,
FlowStateID: flowStateID,
})
}

if linkingTargetUser != nil {
// this means that the user is performing manual linking
claims.LinkingTargetID = linkingTargetUser.ID.String()
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(config.JWT.Secret))
if err != nil {
return internalServerError("Error creating state").WithInternalError(err)
return "", internalServerError("Error creating state").WithInternalError(err)
}

authUrlParams := make([]oauth2.AuthCodeOption, 0)
Expand All @@ -115,20 +133,15 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e
}
}

var authURL string
authURL := p.AuthCodeURL(tokenString, authUrlParams...)
switch externalProvider := p.(type) {
case *provider.TwitterProvider:
authURL = externalProvider.AuthCodeURL(tokenString, authUrlParams...)
err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w)
if err != nil {
return internalServerError("Error storing request token in session").WithInternalError(err)
if err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w); err != nil {
return "", internalServerError("Error storing request token in session").WithInternalError(err)
}
default:
authURL = p.AuthCodeURL(tokenString, authUrlParams...)
}

http.Redirect(w, r, authURL, http.StatusFound)
return nil
return authURL, nil
}

// ExternalProviderCallback handles the callback endpoint in the external oauth provider flow
Expand All @@ -142,37 +155,41 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e
return nil
}

func (a *API) handleOAuthCallback(w http.ResponseWriter, r *http.Request) (*OAuthProviderData, error) {
ctx := r.Context()
providerType := getExternalProviderType(ctx)

var oAuthResponseData *OAuthProviderData
var err error
switch providerType {
case "twitter":
// future OAuth1.0 providers will use this method
oAuthResponseData, err = a.oAuth1Callback(ctx, r, providerType)
default:
oAuthResponseData, err = a.oAuthCallback(ctx, r, providerType)
}
if err != nil {
return nil, err
}
return oAuthResponseData, nil
}

func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config

providerType := getExternalProviderType(ctx)
var userData *provider.UserProvidedData
var providerAccessToken string
var providerRefreshToken string
var grantParams models.GrantParams
var err error

grantParams.FillGrantParams(r)

if providerType == "twitter" {
// future OAuth1.0 providers will use this method
oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType)
if err != nil {
return err
}
userData = oAuthResponseData.userData
providerAccessToken = oAuthResponseData.token
} else {
oAuthResponseData, err := a.oAuthCallback(ctx, r, providerType)
if err != nil {
return err
}
userData = oAuthResponseData.userData
providerAccessToken = oAuthResponseData.token
providerRefreshToken = oAuthResponseData.refreshToken
providerType := getExternalProviderType(ctx)
data, err := a.handleOAuthCallback(w, r)
if err != nil {
return err
}
userData := data.userData
providerAccessToken := data.token
providerRefreshToken := data.refreshToken

var flowState *models.FlowState
// if there's a non-empty FlowStateID we perform PKCE Flow
Expand All @@ -187,8 +204,11 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
var token *AccessTokenResponse
err = db.Transaction(func(tx *storage.Connection) error {
var terr error
inviteToken := getInviteToken(ctx)
if inviteToken != "" {
if targetUser := getTargetUser(ctx); targetUser != nil {
if user, terr = a.linkIdentityToUser(ctx, tx, userData, providerType); terr != nil {
return terr
}
} else if inviteToken := getInviteToken(ctx); inviteToken != "" {
if user, terr = a.processInvite(r, ctx, tx, userData, inviteToken, providerType); terr != nil {
return terr
}
Expand Down Expand Up @@ -479,6 +499,20 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont
if claims.FlowStateID != "" {
ctx = withFlowStateID(ctx, claims.FlowStateID)
}
if claims.LinkingTargetID != "" {
linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
if err != nil {
return nil, badRequestError("invalid target user id")
}
u, err := models.FindUserByID(a.db, linkingTargetUserID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("Linking target user not found")
}
return nil, internalServerError("Database error loading user").WithInternalError(err)
}
ctx = withTargetUser(ctx, u)
}
ctx = withExternalProviderType(ctx, claims.Provider)
return withSignature(ctx, state), nil
}
Expand Down
45 changes: 44 additions & 1 deletion internal/api/identity.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package api

import (
"context"
"net/http"

"github.com/fatih/structs"
"github.com/go-chi/chi"
"github.com/gofrs/uuid"
"github.com/supabase/gotrue/internal/api/provider"
"github.com/supabase/gotrue/internal/models"
"github.com/supabase/gotrue/internal/storage"
)
Expand All @@ -29,7 +32,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {

user := getUser(ctx)
if len(user.Identities) <= 1 {
return badRequestError("Cannot unlink identity from user. User must have at least 1 identity after unlinking")
return badRequestError("User must have at least 1 identity after unlinking")
}
var identityToBeDeleted *models.Identity
for i := range user.Identities {
Expand Down Expand Up @@ -65,3 +68,43 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {

return sendJSON(w, http.StatusOK, map[string]interface{}{})
}

func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
user := getUser(ctx)
rurl, err := a.GetExternalProviderRedirectURL(w, r, user)
if err != nil {
return err
}
skipHTTPRedirect := r.URL.Query().Get("skip_http_redirect") == "true"
if skipHTTPRedirect {
return sendJSON(w, http.StatusOK, map[string]interface{}{
"url": rurl,
})
}
http.Redirect(w, r, rurl, http.StatusFound)
return nil
}

func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, userData *provider.UserProvidedData, providerType string) (*models.User, error) {
targetUser := getTargetUser(ctx)
identity, terr := models.FindIdentityByIdAndProvider(tx, userData.Metadata.Subject, providerType)
if terr != nil {
if !models.IsNotFoundError(terr) {
return nil, internalServerError("Database error finding identity for linking").WithInternalError(terr)
}
}
if identity != nil {
if identity.UserID == targetUser.ID {
return nil, badRequestError("Identity is already linked")
}
return nil, badRequestError("Identity is already linked to another user")
}
if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil {
return nil, terr
}
if terr := targetUser.UpdateAppMetaDataProviders(tx); terr != nil {
return nil, terr
}
return targetUser, nil
}
Loading

0 comments on commit e39ce7d

Please sign in to comment.