Skip to content

Commit e39ce7d

Browse files
kangmingtayhf
andauthored
feat: add manual linking APIs (supabase#1317)
## What kind of change does this PR introduce? * Adds a new endpoint `GET /user/identities/authorize` which is an endpoint to initiate the manual linking process and can only be invoked if the user is authenticated * `GET /user/identities/authorize` functions similarly to `GET /authorize` where the user needs to login to the new oauth identity in order to link the identity * Example ```curl // sign in with one of the supported auth methods to get the user's access token JWT first // start the identity linking process $ curl -X GET "http://localhost:9999/user/identities/authorize?provider=google" -H "Authorization: Bearer ACCESS_TOKEN_JWT" {"url":"https://oauth_provider_url.com/path/to/sign-in"} // visit the url returned and login to the oauth provider // request will be redirected to the /callback endpoint // if the identity is successfully linked, the request will be redirected to `http://localhost:3000/#access_token=xxx&....` // if the identity already exists, the request will be redirect to: // http://localhost:3000/?error=invalid_request&error_code=400&error_description=Identity+is+already+linked+to+another+user#error=invalid_request&error_code=400&error_description=Identity+is+already+linked+to+another+user ``` ## Details * The callback endpoint used will be the same callback as the oauth sign-in flow so that the developer doesn't have to add any additional callback URLs to the oauth provider in order to enable manual linking * A special field `LinkingTargetId` is introduced in the oauth state to store the linking target user ID. This ID will be used in the callback to determine the target user to link the candidate identity used * If the identity is already linked to the current user or another user, an error will be returned * If the identity doesn't exist, then it will be successfully linked to the existing user and a new access & refresh token will be issued. --------- Co-authored-by: Stojan Dimitrovski <[email protected]>
1 parent 9b7d7f6 commit e39ce7d

File tree

7 files changed

+235
-50
lines changed

7 files changed

+235
-50
lines changed

internal/api/api.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
153153
r.Get("/", api.UserGet)
154154
r.With(sharedLimiter).Put("/", api.UserUpdate)
155155

156-
r.Delete("/identities/{identity_id}", api.DeleteIdentity)
156+
r.Route("/identities", func(r *router) {
157+
r.Use(api.requireManualLinkingEnabled)
158+
r.Get("/authorize", api.LinkIdentity)
159+
r.Delete("/{identity_id}", api.DeleteIdentity)
160+
})
157161
})
158162

159163
r.With(api.requireAuthentication).Route("/factors", func(r *router) {

internal/api/context.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ const (
2121
signatureKey = contextKey("signature")
2222
externalProviderTypeKey = contextKey("external_provider_type")
2323
userKey = contextKey("user")
24+
targetUserKey = contextKey("target_user")
2425
factorKey = contextKey("factor")
2526
sessionKey = contextKey("session")
2627
externalReferrerKey = contextKey("external_referrer")
@@ -76,6 +77,11 @@ func withUser(ctx context.Context, u *models.User) context.Context {
7677
return context.WithValue(ctx, userKey, u)
7778
}
7879

80+
// withTargetUser adds the target user for linking to the context.
81+
func withTargetUser(ctx context.Context, u *models.User) context.Context {
82+
return context.WithValue(ctx, targetUserKey, u)
83+
}
84+
7985
// with Factor adds the factor id to the context.
8086
func withFactor(ctx context.Context, f *models.Factor) context.Context {
8187
return context.WithValue(ctx, factorKey, f)
@@ -93,6 +99,18 @@ func getUser(ctx context.Context) *models.User {
9399
return obj.(*models.User)
94100
}
95101

102+
// getTargetUser reads the user from the context.
103+
func getTargetUser(ctx context.Context) *models.User {
104+
if ctx == nil {
105+
return nil
106+
}
107+
obj := ctx.Value(targetUserKey)
108+
if obj == nil {
109+
return nil
110+
}
111+
return obj.(*models.User)
112+
}
113+
96114
// getFactor reads the factor id from the context
97115
func getFactor(ctx context.Context) *models.Factor {
98116
obj := ctx.Value(factorKey)

internal/api/external.go

Lines changed: 82 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,25 @@ import (
2525
// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
2626
type ExternalProviderClaims struct {
2727
AuthMicroserviceClaims
28-
Provider string `json:"provider"`
29-
InviteToken string `json:"invite_token,omitempty"`
30-
Referrer string `json:"referrer,omitempty"`
31-
FlowStateID string `json:"flow_state_id"`
28+
Provider string `json:"provider"`
29+
InviteToken string `json:"invite_token,omitempty"`
30+
Referrer string `json:"referrer,omitempty"`
31+
FlowStateID string `json:"flow_state_id"`
32+
LinkingTargetID string `json:"linking_target_id,omitempty"`
3233
}
3334

34-
// ExternalProviderRedirect redirects the request to the corresponding oauth provider
35+
// ExternalProviderRedirect redirects the request to the oauth provider
3536
func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error {
37+
rurl, err := a.GetExternalProviderRedirectURL(w, r, nil)
38+
if err != nil {
39+
return err
40+
}
41+
http.Redirect(w, r, rurl, http.StatusFound)
42+
return nil
43+
}
44+
45+
// GetExternalProviderRedirectURL returns the URL to start the oauth flow with the corresponding oauth provider
46+
func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request, linkingTargetUser *models.User) (string, error) {
3647
ctx := r.Context()
3748
db := a.db.WithContext(ctx)
3849
config := a.config
@@ -45,45 +56,45 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e
4556

4657
p, err := a.Provider(ctx, providerType, scopes)
4758
if err != nil {
48-
return badRequestError("Unsupported provider: %+v", err).WithInternalError(err)
59+
return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err)
4960
}
5061

5162
inviteToken := query.Get("invite_token")
5263
if inviteToken != "" {
5364
_, userErr := models.FindUserByConfirmationToken(db, inviteToken)
5465
if userErr != nil {
5566
if models.IsNotFoundError(userErr) {
56-
return notFoundError(userErr.Error())
67+
return "", notFoundError(userErr.Error())
5768
}
58-
return internalServerError("Database error finding user").WithInternalError(userErr)
69+
return "", internalServerError("Database error finding user").WithInternalError(userErr)
5970
}
6071
}
6172

6273
redirectURL := utilities.GetReferrer(r, config)
6374
log := observability.GetLogEntry(r)
6475
log.WithField("provider", providerType).Info("Redirecting to external provider")
6576
if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
66-
return err
77+
return "", err
6778
}
6879
flowType := getFlowFromChallenge(codeChallenge)
6980

7081
flowStateID := ""
7182
if flowType == models.PKCEFlow {
7283
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
7384
if err != nil {
74-
return err
85+
return "", err
7586
}
7687
flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth)
7788
if err != nil {
78-
return err
89+
return "", err
7990
}
8091
if err := a.db.Create(flowState); err != nil {
81-
return err
92+
return "", err
8293
}
8394
flowStateID = flowState.ID.String()
8495
}
8596

86-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, ExternalProviderClaims{
97+
claims := ExternalProviderClaims{
8798
AuthMicroserviceClaims: AuthMicroserviceClaims{
8899
StandardClaims: jwt.StandardClaims{
89100
ExpiresAt: time.Now().Add(5 * time.Minute).Unix(),
@@ -95,10 +106,17 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e
95106
InviteToken: inviteToken,
96107
Referrer: redirectURL,
97108
FlowStateID: flowStateID,
98-
})
109+
}
110+
111+
if linkingTargetUser != nil {
112+
// this means that the user is performing manual linking
113+
claims.LinkingTargetID = linkingTargetUser.ID.String()
114+
}
115+
116+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
99117
tokenString, err := token.SignedString([]byte(config.JWT.Secret))
100118
if err != nil {
101-
return internalServerError("Error creating state").WithInternalError(err)
119+
return "", internalServerError("Error creating state").WithInternalError(err)
102120
}
103121

104122
authUrlParams := make([]oauth2.AuthCodeOption, 0)
@@ -115,20 +133,15 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e
115133
}
116134
}
117135

118-
var authURL string
136+
authURL := p.AuthCodeURL(tokenString, authUrlParams...)
119137
switch externalProvider := p.(type) {
120138
case *provider.TwitterProvider:
121-
authURL = externalProvider.AuthCodeURL(tokenString, authUrlParams...)
122-
err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w)
123-
if err != nil {
124-
return internalServerError("Error storing request token in session").WithInternalError(err)
139+
if err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w); err != nil {
140+
return "", internalServerError("Error storing request token in session").WithInternalError(err)
125141
}
126-
default:
127-
authURL = p.AuthCodeURL(tokenString, authUrlParams...)
128142
}
129143

130-
http.Redirect(w, r, authURL, http.StatusFound)
131-
return nil
144+
return authURL, nil
132145
}
133146

134147
// ExternalProviderCallback handles the callback endpoint in the external oauth provider flow
@@ -142,37 +155,41 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e
142155
return nil
143156
}
144157

158+
func (a *API) handleOAuthCallback(w http.ResponseWriter, r *http.Request) (*OAuthProviderData, error) {
159+
ctx := r.Context()
160+
providerType := getExternalProviderType(ctx)
161+
162+
var oAuthResponseData *OAuthProviderData
163+
var err error
164+
switch providerType {
165+
case "twitter":
166+
// future OAuth1.0 providers will use this method
167+
oAuthResponseData, err = a.oAuth1Callback(ctx, r, providerType)
168+
default:
169+
oAuthResponseData, err = a.oAuthCallback(ctx, r, providerType)
170+
}
171+
if err != nil {
172+
return nil, err
173+
}
174+
return oAuthResponseData, nil
175+
}
176+
145177
func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Request) error {
146178
ctx := r.Context()
147179
db := a.db.WithContext(ctx)
148180
config := a.config
149181

150-
providerType := getExternalProviderType(ctx)
151-
var userData *provider.UserProvidedData
152-
var providerAccessToken string
153-
var providerRefreshToken string
154182
var grantParams models.GrantParams
155-
var err error
156-
157183
grantParams.FillGrantParams(r)
158184

159-
if providerType == "twitter" {
160-
// future OAuth1.0 providers will use this method
161-
oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType)
162-
if err != nil {
163-
return err
164-
}
165-
userData = oAuthResponseData.userData
166-
providerAccessToken = oAuthResponseData.token
167-
} else {
168-
oAuthResponseData, err := a.oAuthCallback(ctx, r, providerType)
169-
if err != nil {
170-
return err
171-
}
172-
userData = oAuthResponseData.userData
173-
providerAccessToken = oAuthResponseData.token
174-
providerRefreshToken = oAuthResponseData.refreshToken
185+
providerType := getExternalProviderType(ctx)
186+
data, err := a.handleOAuthCallback(w, r)
187+
if err != nil {
188+
return err
175189
}
190+
userData := data.userData
191+
providerAccessToken := data.token
192+
providerRefreshToken := data.refreshToken
176193

177194
var flowState *models.FlowState
178195
// if there's a non-empty FlowStateID we perform PKCE Flow
@@ -187,8 +204,11 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
187204
var token *AccessTokenResponse
188205
err = db.Transaction(func(tx *storage.Connection) error {
189206
var terr error
190-
inviteToken := getInviteToken(ctx)
191-
if inviteToken != "" {
207+
if targetUser := getTargetUser(ctx); targetUser != nil {
208+
if user, terr = a.linkIdentityToUser(ctx, tx, userData, providerType); terr != nil {
209+
return terr
210+
}
211+
} else if inviteToken := getInviteToken(ctx); inviteToken != "" {
192212
if user, terr = a.processInvite(r, ctx, tx, userData, inviteToken, providerType); terr != nil {
193213
return terr
194214
}
@@ -479,6 +499,20 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont
479499
if claims.FlowStateID != "" {
480500
ctx = withFlowStateID(ctx, claims.FlowStateID)
481501
}
502+
if claims.LinkingTargetID != "" {
503+
linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
504+
if err != nil {
505+
return nil, badRequestError("invalid target user id")
506+
}
507+
u, err := models.FindUserByID(a.db, linkingTargetUserID)
508+
if err != nil {
509+
if models.IsNotFoundError(err) {
510+
return nil, notFoundError("Linking target user not found")
511+
}
512+
return nil, internalServerError("Database error loading user").WithInternalError(err)
513+
}
514+
ctx = withTargetUser(ctx, u)
515+
}
482516
ctx = withExternalProviderType(ctx, claims.Provider)
483517
return withSignature(ctx, state), nil
484518
}

internal/api/identity.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package api
22

33
import (
4+
"context"
45
"net/http"
56

7+
"github.com/fatih/structs"
68
"github.com/go-chi/chi"
79
"github.com/gofrs/uuid"
10+
"github.com/supabase/gotrue/internal/api/provider"
811
"github.com/supabase/gotrue/internal/models"
912
"github.com/supabase/gotrue/internal/storage"
1013
)
@@ -29,7 +32,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
2932

3033
user := getUser(ctx)
3134
if len(user.Identities) <= 1 {
32-
return badRequestError("Cannot unlink identity from user. User must have at least 1 identity after unlinking")
35+
return badRequestError("User must have at least 1 identity after unlinking")
3336
}
3437
var identityToBeDeleted *models.Identity
3538
for i := range user.Identities {
@@ -65,3 +68,43 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
6568

6669
return sendJSON(w, http.StatusOK, map[string]interface{}{})
6770
}
71+
72+
func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error {
73+
ctx := r.Context()
74+
user := getUser(ctx)
75+
rurl, err := a.GetExternalProviderRedirectURL(w, r, user)
76+
if err != nil {
77+
return err
78+
}
79+
skipHTTPRedirect := r.URL.Query().Get("skip_http_redirect") == "true"
80+
if skipHTTPRedirect {
81+
return sendJSON(w, http.StatusOK, map[string]interface{}{
82+
"url": rurl,
83+
})
84+
}
85+
http.Redirect(w, r, rurl, http.StatusFound)
86+
return nil
87+
}
88+
89+
func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, userData *provider.UserProvidedData, providerType string) (*models.User, error) {
90+
targetUser := getTargetUser(ctx)
91+
identity, terr := models.FindIdentityByIdAndProvider(tx, userData.Metadata.Subject, providerType)
92+
if terr != nil {
93+
if !models.IsNotFoundError(terr) {
94+
return nil, internalServerError("Database error finding identity for linking").WithInternalError(terr)
95+
}
96+
}
97+
if identity != nil {
98+
if identity.UserID == targetUser.ID {
99+
return nil, badRequestError("Identity is already linked")
100+
}
101+
return nil, badRequestError("Identity is already linked to another user")
102+
}
103+
if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil {
104+
return nil, terr
105+
}
106+
if terr := targetUser.UpdateAppMetaDataProviders(tx); terr != nil {
107+
return nil, terr
108+
}
109+
return targetUser, nil
110+
}

0 commit comments

Comments
 (0)