Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add manual linking APIs #1317

Merged
merged 14 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
r.Get("/", api.UserGet)
r.With(sharedLimiter).Put("/", api.UserUpdate)

r.Delete("/identities/{identity_id}", api.DeleteIdentity)
r.Route("/identities", func(r *router) {
r.Use(api.requireManualLinkingEnabled)
r.Get("/authorize", api.LinkIdentity)
r.Delete("/{identity_id}", api.DeleteIdentity)
})
})

r.With(api.requireAuthentication).Route("/factors", func(r *router) {
Expand Down
18 changes: 18 additions & 0 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
signatureKey = contextKey("signature")
externalProviderTypeKey = contextKey("external_provider_type")
userKey = contextKey("user")
targetUserKey = contextKey("target_user")
factorKey = contextKey("factor")
sessionKey = contextKey("session")
externalReferrerKey = contextKey("external_referrer")
Expand Down Expand Up @@ -76,6 +77,11 @@ func withUser(ctx context.Context, u *models.User) context.Context {
return context.WithValue(ctx, userKey, u)
}

// withTargetUser adds the target user for linking to the context.
func withTargetUser(ctx context.Context, u *models.User) context.Context {
return context.WithValue(ctx, targetUserKey, u)
}

// with Factor adds the factor id to the context.
func withFactor(ctx context.Context, f *models.Factor) context.Context {
return context.WithValue(ctx, factorKey, f)
Expand All @@ -93,6 +99,18 @@ func getUser(ctx context.Context) *models.User {
return obj.(*models.User)
}

// getTargetUser reads the user from the context.
func getTargetUser(ctx context.Context) *models.User {
if ctx == nil {
return nil
}
obj := ctx.Value(targetUserKey)
if obj == nil {
return nil
}
return obj.(*models.User)
}

// getFactor reads the factor id from the context
func getFactor(ctx context.Context) *models.Factor {
obj := ctx.Value(factorKey)
Expand Down
130 changes: 82 additions & 48 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,25 @@ import (
// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
type ExternalProviderClaims struct {
AuthMicroserviceClaims
Provider string `json:"provider"`
InviteToken string `json:"invite_token,omitempty"`
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
Provider string `json:"provider"`
InviteToken string `json:"invite_token,omitempty"`
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
LinkingTargetID string `json:"linking_target_id,omitempty"`
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
}

// ExternalProviderRedirect redirects the request to the corresponding oauth provider
// ExternalProviderRedirect redirects the request to the oauth provider
func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error {
rurl, err := a.GetExternalProviderRedirectURL(w, r, nil)
if err != nil {
return err
}
http.Redirect(w, r, rurl, http.StatusFound)
return nil
}

// GetExternalProviderRedirectURL returns the URL to start the oauth flow with the corresponding oauth provider
func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request, linkingTargetUser *models.User) (string, error) {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config
Expand All @@ -45,45 +56,45 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e

p, err := a.Provider(ctx, providerType, scopes)
if err != nil {
return badRequestError("Unsupported provider: %+v", err).WithInternalError(err)
return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err)
}

inviteToken := query.Get("invite_token")
if inviteToken != "" {
_, userErr := models.FindUserByConfirmationToken(db, inviteToken)
if userErr != nil {
if models.IsNotFoundError(userErr) {
return notFoundError(userErr.Error())
return "", notFoundError(userErr.Error())
}
return internalServerError("Database error finding user").WithInternalError(userErr)
return "", internalServerError("Database error finding user").WithInternalError(userErr)
}
}

redirectURL := utilities.GetReferrer(r, config)
log := observability.GetLogEntry(r)
log.WithField("provider", providerType).Info("Redirecting to external provider")
if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
return err
return "", err
}
flowType := getFlowFromChallenge(codeChallenge)

flowStateID := ""
if flowType == models.PKCEFlow {
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return err
return "", err
}
flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth)
if err != nil {
return err
return "", err
}
if err := a.db.Create(flowState); err != nil {
return err
return "", err
}
flowStateID = flowState.ID.String()
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, ExternalProviderClaims{
claims := ExternalProviderClaims{
AuthMicroserviceClaims: AuthMicroserviceClaims{
StandardClaims: jwt.StandardClaims{
ExpiresAt: time.Now().Add(5 * time.Minute).Unix(),
Expand All @@ -95,10 +106,17 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e
InviteToken: inviteToken,
Referrer: redirectURL,
FlowStateID: flowStateID,
})
}

if linkingTargetUser != nil {
// this means that the user is performing manual linking
claims.LinkingTargetID = linkingTargetUser.ID.String()
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(config.JWT.Secret))
if err != nil {
return internalServerError("Error creating state").WithInternalError(err)
return "", internalServerError("Error creating state").WithInternalError(err)
}

authUrlParams := make([]oauth2.AuthCodeOption, 0)
Expand All @@ -115,20 +133,15 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e
}
}

var authURL string
authURL := p.AuthCodeURL(tokenString, authUrlParams...)
switch externalProvider := p.(type) {
case *provider.TwitterProvider:
authURL = externalProvider.AuthCodeURL(tokenString, authUrlParams...)
err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w)
if err != nil {
return internalServerError("Error storing request token in session").WithInternalError(err)
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
if err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w); err != nil {
return "", internalServerError("Error storing request token in session").WithInternalError(err)
}
default:
authURL = p.AuthCodeURL(tokenString, authUrlParams...)
}

http.Redirect(w, r, authURL, http.StatusFound)
return nil
return authURL, nil
}

// ExternalProviderCallback handles the callback endpoint in the external oauth provider flow
Expand All @@ -142,37 +155,41 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e
return nil
}

func (a *API) handleOAuthCallback(w http.ResponseWriter, r *http.Request) (*OAuthProviderData, error) {
ctx := r.Context()
providerType := getExternalProviderType(ctx)

var oAuthResponseData *OAuthProviderData
var err error
switch providerType {
case "twitter":
// future OAuth1.0 providers will use this method
oAuthResponseData, err = a.oAuth1Callback(ctx, r, providerType)
default:
oAuthResponseData, err = a.oAuthCallback(ctx, r, providerType)
}
if err != nil {
return nil, err
}
return oAuthResponseData, nil
}

func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config

providerType := getExternalProviderType(ctx)
var userData *provider.UserProvidedData
var providerAccessToken string
var providerRefreshToken string
var grantParams models.GrantParams
var err error

grantParams.FillGrantParams(r)

if providerType == "twitter" {
// future OAuth1.0 providers will use this method
oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType)
if err != nil {
return err
}
userData = oAuthResponseData.userData
providerAccessToken = oAuthResponseData.token
} else {
oAuthResponseData, err := a.oAuthCallback(ctx, r, providerType)
if err != nil {
return err
}
userData = oAuthResponseData.userData
providerAccessToken = oAuthResponseData.token
providerRefreshToken = oAuthResponseData.refreshToken
providerType := getExternalProviderType(ctx)
data, err := a.handleOAuthCallback(w, r)
if err != nil {
return err
}
userData := data.userData
providerAccessToken := data.token
providerRefreshToken := data.refreshToken

var flowState *models.FlowState
// if there's a non-empty FlowStateID we perform PKCE Flow
Expand All @@ -187,8 +204,11 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
var token *AccessTokenResponse
err = db.Transaction(func(tx *storage.Connection) error {
var terr error
inviteToken := getInviteToken(ctx)
if inviteToken != "" {
if targetUser := getTargetUser(ctx); targetUser != nil {
if user, terr = a.linkIdentityToUser(ctx, tx, userData, providerType); terr != nil {
return terr
}
} else if inviteToken := getInviteToken(ctx); inviteToken != "" {
if user, terr = a.processInvite(r, ctx, tx, userData, inviteToken, providerType); terr != nil {
return terr
}
Expand Down Expand Up @@ -471,6 +491,20 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont
if claims.FlowStateID != "" {
ctx = withFlowStateID(ctx, claims.FlowStateID)
}
if claims.LinkingTargetID != "" {
linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
if err != nil {
return nil, badRequestError("invalid target user id")
}
u, err := models.FindUserByID(a.db, linkingTargetUserID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("Linking target user not found")
}
return nil, internalServerError("Database error loading user").WithInternalError(err)
}
ctx = withTargetUser(ctx, u)
}
ctx = withExternalProviderType(ctx, claims.Provider)
return withSignature(ctx, state), nil
}
Expand Down
45 changes: 44 additions & 1 deletion internal/api/identity.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package api

import (
"context"
"net/http"

"github.com/fatih/structs"
"github.com/go-chi/chi"
"github.com/gofrs/uuid"
"github.com/supabase/gotrue/internal/api/provider"
"github.com/supabase/gotrue/internal/models"
"github.com/supabase/gotrue/internal/storage"
)
Expand All @@ -29,7 +32,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {

user := getUser(ctx)
if len(user.Identities) <= 1 {
return badRequestError("Cannot unlink identity from user. User must have at least 1 identity after unlinking")
return badRequestError("User must have at least 1 identity after unlinking")
}
var identityToBeDeleted *models.Identity
for i := range user.Identities {
Expand Down Expand Up @@ -65,3 +68,43 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {

return sendJSON(w, http.StatusOK, map[string]interface{}{})
}

func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
user := getUser(ctx)
rurl, err := a.GetExternalProviderRedirectURL(w, r, user)
if err != nil {
return err
}
skipHTTPRedirect := r.URL.Query().Get("skip_http_redirect") == "true"
if skipHTTPRedirect {
return sendJSON(w, http.StatusOK, map[string]interface{}{
"url": rurl,
})
}
http.Redirect(w, r, rurl, http.StatusFound)
return nil
}

func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, userData *provider.UserProvidedData, providerType string) (*models.User, error) {
targetUser := getTargetUser(ctx)
identity, terr := models.FindIdentityByIdAndProvider(tx, userData.Metadata.Subject, providerType)
if terr != nil {
if !models.IsNotFoundError(terr) {
return nil, internalServerError("Database error finding identity for linking").WithInternalError(terr)
}
}
if identity != nil {
if identity.UserID == targetUser.ID {
return nil, badRequestError("Identity is already linked")
}
return nil, badRequestError("Identity is already linked to another user")
}
if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil {
return nil, terr
}
if terr := targetUser.UpdateAppMetaDataProviders(tx); terr != nil {
return nil, terr
}
return targetUser, nil
}
Loading
Loading