From 513200996f3312fd5364db96e4e2a89f2670b5c6 Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Fri, 17 Nov 2023 17:31:02 -0500 Subject: [PATCH 01/11] feat: update primary key for identities table --- internal/api/admin_test.go | 2 +- internal/models/identity.go | 14 ++++++--- internal/models/user.go | 13 +++----- ...231117164230_add_id_pkey_identities.up.sql | 31 +++++++++++++++++++ 4 files changed, 46 insertions(+), 14 deletions(-) create mode 100644 migrations/20231117164230_add_id_pkey_identities.up.sql diff --git a/internal/api/admin_test.go b/internal/api/admin_test.go index c3f5615f6e..204a14eb89 100644 --- a/internal/api/admin_test.go +++ b/internal/api/admin_test.go @@ -438,7 +438,7 @@ func (ts *AdminTestSuite) TestAdminUserUpdate() { for _, identity := range u.Identities { // for email & phone identities, the providerId is the same as the userId - require.Equal(ts.T(), u.ID.String(), identity.ID) + require.Equal(ts.T(), u.ID.String(), identity.ProviderID) require.Equal(ts.T(), u.ID, identity.UserID) if identity.Provider == "email" { require.Equal(ts.T(), newEmail, identity.IdentityData["email"]) diff --git a/internal/models/identity.go b/internal/models/identity.go index 2e8f87a7dd..1ada9130e9 100644 --- a/internal/models/identity.go +++ b/internal/models/identity.go @@ -12,7 +12,12 @@ import ( ) type Identity struct { - ID string `json:"id" db:"id"` + // returned as identity_id in JSON for backward compatibility with the interface exposed by the client library + // see https://github.com/supabase/gotrue-js/blob/c9296bbc27a2f036af55c1f33fca5930704bd021/src/lib/types.ts#L230-L240 + ID uuid.UUID `json:"identity_id" db:"id"` + // returned as id in JSON for backward compatibility with the interface exposed by the client library + // see https://github.com/supabase/gotrue-js/blob/c9296bbc27a2f036af55c1f33fca5930704bd021/src/lib/types.ts#L230-L240 + ProviderID string `json:"id" db:"provider_id"` UserID uuid.UUID `json:"user_id" db:"user_id"` IdentityData JSONMap `json:"identity_data,omitempty" db:"identity_data"` Provider string `json:"provider" db:"provider"` @@ -35,7 +40,7 @@ func NewIdentity(user *User, provider string, identityData map[string]interface{ now := time.Now() identity := &Identity{ - ID: providerId.(string), + ProviderID: providerId.(string), UserID: user.ID, IdentityData: identityData, Provider: provider, @@ -63,7 +68,7 @@ func (i *Identity) IsForSSOProvider() bool { // FindIdentityById searches for an identity with the matching id and provider given. func FindIdentityByIdAndProvider(tx *storage.Connection, providerId, provider string) (*Identity, error) { identity := &Identity{} - if err := tx.Q().Where("id = ? AND provider = ?", providerId, provider).First(identity); err != nil { + if err := tx.Q().Where("provider_id = ? AND provider = ?", providerId, provider).First(identity); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, IdentityNotFoundError{} } @@ -117,9 +122,8 @@ func (i *Identity) UpdateIdentityData(tx *storage.Connection, updates map[string } // pop doesn't support updates on tables with composite primary keys so we use a raw query here. return tx.RawQuery( - "update "+(&pop.Model{Value: Identity{}}).TableName()+" set identity_data = ? where provider = ? and id = ?", + "update "+(&pop.Model{Value: Identity{}}).TableName()+" set identity_data = ? where id = ?", i.IdentityData, - i.Provider, i.ID, ).Exec() } diff --git a/internal/models/user.go b/internal/models/user.go index b9debb1afd..7e415f096e 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -673,9 +673,7 @@ func (u *User) RemoveUnconfirmedIdentities(tx *storage.Connection, identity *Ide // finally, remove all identities except the current identity being authenticated for i := range u.Identities { - identityId := u.Identities[i].Provider + u.Identities[i].ID - identityIdToKeep := identity.Provider + identity.ID - if identityId != identityIdToKeep { + if u.Identities[i].ID != identity.ID { if terr := tx.Destroy(&u.Identities[i]); terr != nil { return terr } @@ -762,10 +760,9 @@ func (u *User) SoftDeleteUserIdentities(tx *storage.Connection) error { if err := tx.RawQuery( "update "+ (&pop.Model{Value: Identity{}}).TableName()+ - " set id = ? where id = ? and provider = ?", - obfuscateIdentityId(identity), + " set provider_id = ? where id = ?", + obfuscateIdentityProviderId(identity), identity.ID, - identity.Provider, ).Exec(); err != nil { return err } @@ -787,6 +784,6 @@ func obfuscatePhone(u *User, phone string) string { return obfuscateValue(u.ID, phone)[:15] } -func obfuscateIdentityId(identity *Identity) string { - return obfuscateValue(identity.UserID, identity.Provider+":"+identity.ID) +func obfuscateIdentityProviderId(identity *Identity) string { + return obfuscateValue(identity.UserID, identity.Provider+":"+identity.ProviderID) } diff --git a/migrations/20231117164230_add_id_pkey_identities.up.sql b/migrations/20231117164230_add_id_pkey_identities.up.sql new file mode 100644 index 0000000000..a0a4a67ce7 --- /dev/null +++ b/migrations/20231117164230_add_id_pkey_identities.up.sql @@ -0,0 +1,31 @@ +do $$ +begin + if not exists(select * + from information_schema.columns + where table_schema = '{{ index .Options "Namespace" }}' and table_name='identities' and column_name='provider_id') + then + alter table if exists {{ index .Options "Namespace" }}.identities + rename column id to provider_id; + end if; +end$$; + +alter table if exists {{ index .Options "Namespace" }}.identities + drop constraint if exists identities_pkey; + +alter table if exists {{ index .Options "Namespace" }}.identities + add column if not exists id uuid default gen_random_uuid() primary key; + +do $$ +begin + if not exists + (select constraint_name + from information_schema.table_constraints + where table_schema = '{{ index .Options "Namespace" }}' + and table_name = 'identities' + and constraint_name = 'identities_provider_id_provider_unique') + then + alter table if exists {{ index .Options "Namespace" }}.identities + add constraint identities_provider_id_provider_unique + unique(provider_id, provider); + end if; +end $$; From 4cea55c6d05d4b42a74da12e24f8991165dcf373 Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Mon, 20 Nov 2023 10:07:41 -0500 Subject: [PATCH 02/11] Update migrations/20231117164230_add_id_pkey_identities.up.sql Co-authored-by: Stojan Dimitrovski <sdimitrovski@gmail.com> --- migrations/20231117164230_add_id_pkey_identities.up.sql | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/migrations/20231117164230_add_id_pkey_identities.up.sql b/migrations/20231117164230_add_id_pkey_identities.up.sql index a0a4a67ce7..31ed280d3b 100644 --- a/migrations/20231117164230_add_id_pkey_identities.up.sql +++ b/migrations/20231117164230_add_id_pkey_identities.up.sql @@ -10,9 +10,7 @@ begin end$$; alter table if exists {{ index .Options "Namespace" }}.identities - drop constraint if exists identities_pkey; - -alter table if exists {{ index .Options "Namespace" }}.identities + drop constraint if exists identities_pkey, add column if not exists id uuid default gen_random_uuid() primary key; do $$ From b98471736507208a0712034288de35f94648b83f Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Mon, 20 Nov 2023 12:55:50 -0500 Subject: [PATCH 03/11] feat: unlink identity from user --- internal/api/api.go | 2 + internal/api/identity.go | 67 ++++++++++++++++++++++++++++++ internal/models/audit_log_entry.go | 1 + internal/models/user.go | 8 +++- 4 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 internal/api/identity.go diff --git a/internal/api/api.go b/internal/api/api.go index 78541ce279..d9016fd0b3 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -152,6 +152,8 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.With(api.requireAuthentication).Route("/user", func(r *router) { r.Get("/", api.UserGet) r.With(sharedLimiter).Put("/", api.UserUpdate) + + r.Delete("/identities/{identity_id}", api.DeleteIdentity) }) r.With(api.requireAuthentication).Route("/factors", func(r *router) { diff --git a/internal/api/identity.go b/internal/api/identity.go new file mode 100644 index 0000000000..769dec752b --- /dev/null +++ b/internal/api/identity.go @@ -0,0 +1,67 @@ +package api + +import ( + "net/http" + + "github.com/go-chi/chi" + "github.com/gofrs/uuid" + "github.com/supabase/gotrue/internal/models" + "github.com/supabase/gotrue/internal/storage" +) + +func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + claims := getClaims(ctx) + if claims == nil { + return badRequestError("Could not read claims") + } + + aud := a.requestAud(ctx, r) + if aud != claims.Audience { + return badRequestError("Token audience doesn't match request audience") + } + + identityID, err := uuid.FromString(chi.URLParam(r, "identity_id")) + if err != nil { + return badRequestError("identity_id must be an UUID") + } + + user := getUser(ctx) + if len(user.Identities) <= 1 { + return badRequestError("Cannot unlink identity from user. User must have at least 1 identity after unlinking") + } + var identityToBeDeleted *models.Identity + for i := range user.Identities { + identity := user.Identities[i] + if identity.ID == identityID { + identityToBeDeleted = &identity + break + } + } + if identityToBeDeleted == nil { + return badRequestError("Identity doesn't exist") + } + + err = a.db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.IdentityUnlinkAction, "", map[string]interface{}{ + "identity_id": identityToBeDeleted.ID, + "provider": identityToBeDeleted.Provider, + "provider_id": identityToBeDeleted.ProviderID, + }); terr != nil { + return internalServerError("Error recording audit log entry").WithInternalError(terr) + } + if terr := tx.Destroy(identityToBeDeleted); terr != nil { + return internalServerError("Database error deleting identity").WithInternalError(terr) + } + if terr := user.UpdateAppMetaDataProviders(tx); terr != nil { + return internalServerError("Database error updating user providers").WithInternalError(terr) + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{}) +} diff --git a/internal/models/audit_log_entry.go b/internal/models/audit_log_entry.go index 98223038b8..74ad3d6668 100644 --- a/internal/models/audit_log_entry.go +++ b/internal/models/audit_log_entry.go @@ -40,6 +40,7 @@ const ( DeleteRecoveryCodesAction AuditAction = "recovery_codes_deleted" UpdateFactorAction AuditAction = "factor_updated" MFACodeLoginAction AuditAction = "mfa_code_login" + IdentityUnlinkAction AuditAction = "identity_unlinked" account auditLogType = "account" team auditLogType = "team" diff --git a/internal/models/user.go b/internal/models/user.go index 7e415f096e..ede05f3703 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -210,9 +210,13 @@ func (u *User) UpdateAppMetaDataProviders(tx *storage.Connection) error { if terr != nil { return terr } - return u.UpdateAppMetaData(tx, map[string]interface{}{ + payload := map[string]interface{}{ "providers": providers, - }) + } + if len(providers) > 0 { + payload["provider"] = providers[0] + } + return u.UpdateAppMetaData(tx, payload) } // SetEmail sets the user's email From 38fa243597452c5fe367a61c1770641e106314af Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Tue, 21 Nov 2023 18:11:02 -0500 Subject: [PATCH 04/11] feat: add manual linking apis --- internal/api/api.go | 1 + internal/api/context.go | 18 ++++++ internal/api/external.go | 131 +++++++++++++++++++++++++-------------- internal/api/identity.go | 36 +++++++++++ 4 files changed, 138 insertions(+), 48 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index d9016fd0b3..99ce911803 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -154,6 +154,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.With(sharedLimiter).Put("/", api.UserUpdate) r.Delete("/identities/{identity_id}", api.DeleteIdentity) + r.Get("/identities/authorize", api.LinkIdentity) }) r.With(api.requireAuthentication).Route("/factors", func(r *router) { diff --git a/internal/api/context.go b/internal/api/context.go index 107377b50e..2cd73309fc 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -21,6 +21,7 @@ const ( signatureKey = contextKey("signature") externalProviderTypeKey = contextKey("external_provider_type") userKey = contextKey("user") + targetUserKey = contextKey("target_user") factorKey = contextKey("factor") sessionKey = contextKey("session") externalReferrerKey = contextKey("external_referrer") @@ -76,6 +77,11 @@ func withUser(ctx context.Context, u *models.User) context.Context { return context.WithValue(ctx, userKey, u) } +// withTargetUser adds the target user for linking to the context. +func withTargetUser(ctx context.Context, u *models.User) context.Context { + return context.WithValue(ctx, targetUserKey, u) +} + // with Factor adds the factor id to the context. func withFactor(ctx context.Context, f *models.Factor) context.Context { return context.WithValue(ctx, factorKey, f) @@ -93,6 +99,18 @@ func getUser(ctx context.Context) *models.User { return obj.(*models.User) } +// getTargetUser reads the user from the context. +func getTargetUser(ctx context.Context) *models.User { + if ctx == nil { + return nil + } + obj := ctx.Value(targetUserKey) + if obj == nil { + return nil + } + return obj.(*models.User) +} + // getFactor reads the factor id from the context func getFactor(ctx context.Context) *models.Factor { obj := ctx.Value(factorKey) diff --git a/internal/api/external.go b/internal/api/external.go index ea25c2af88..c976714e29 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -25,14 +25,25 @@ import ( // ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow type ExternalProviderClaims struct { AuthMicroserviceClaims - Provider string `json:"provider"` - InviteToken string `json:"invite_token,omitempty"` - Referrer string `json:"referrer,omitempty"` - FlowStateID string `json:"flow_state_id"` + Provider string `json:"provider"` + InviteToken string `json:"invite_token,omitempty"` + Referrer string `json:"referrer,omitempty"` + FlowStateID string `json:"flow_state_id"` + LinkingTargetID string `json:"linking_target_id,omitempty"` } -// ExternalProviderRedirect redirects the request to the corresponding oauth provider +// ExternalProviderRedirect redirects the request to the oauth provider func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error { + rurl, err := a.GetExternalProviderRedirectURL(w, r) + if err != nil { + return err + } + http.Redirect(w, r, rurl, http.StatusFound) + return nil +} + +// GetExternalProviderRedirectURL returns the URL to start the oauth flow with the corresponding oauth provider +func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request) (string, error) { ctx := r.Context() db := a.db.WithContext(ctx) config := a.config @@ -45,7 +56,7 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e p, err := a.Provider(ctx, providerType, scopes) if err != nil { - return badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err) } inviteToken := query.Get("invite_token") @@ -53,9 +64,9 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e _, userErr := models.FindUserByConfirmationToken(db, inviteToken) if userErr != nil { if models.IsNotFoundError(userErr) { - return notFoundError(userErr.Error()) + return "", notFoundError(userErr.Error()) } - return internalServerError("Database error finding user").WithInternalError(userErr) + return "", internalServerError("Database error finding user").WithInternalError(userErr) } } @@ -63,7 +74,7 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e log := observability.GetLogEntry(r) log.WithField("provider", providerType).Info("Redirecting to external provider") if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil { - return err + return "", err } flowType := getFlowFromChallenge(codeChallenge) @@ -71,19 +82,19 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e if flowType == models.PKCEFlow { codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod) if err != nil { - return err + return "", err } flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) if err != nil { - return err + return "", err } if err := a.db.Create(flowState); err != nil { - return err + return "", err } flowStateID = flowState.ID.String() } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, ExternalProviderClaims{ + claims := ExternalProviderClaims{ AuthMicroserviceClaims: AuthMicroserviceClaims{ StandardClaims: jwt.StandardClaims{ ExpiresAt: time.Now().Add(5 * time.Minute).Unix(), @@ -95,10 +106,18 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e InviteToken: inviteToken, Referrer: redirectURL, FlowStateID: flowStateID, - }) + } + + if strings.Contains(r.URL.Path, "identities") { + // this means that the user is performing manual linking + user := getUser(ctx) + claims.LinkingTargetID = user.ID.String() + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte(config.JWT.Secret)) if err != nil { - return internalServerError("Error creating state").WithInternalError(err) + return "", internalServerError("Error creating state").WithInternalError(err) } authUrlParams := make([]oauth2.AuthCodeOption, 0) @@ -115,20 +134,15 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e } } - var authURL string + authURL := p.AuthCodeURL(tokenString, authUrlParams...) switch externalProvider := p.(type) { case *provider.TwitterProvider: - authURL = externalProvider.AuthCodeURL(tokenString, authUrlParams...) - err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w) - if err != nil { - return internalServerError("Error storing request token in session").WithInternalError(err) + if err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w); err != nil { + return "", internalServerError("Error storing request token in session").WithInternalError(err) } - default: - authURL = p.AuthCodeURL(tokenString, authUrlParams...) } - http.Redirect(w, r, authURL, http.StatusFound) - return nil + return authURL, nil } // ExternalProviderCallback handles the callback endpoint in the external oauth provider flow @@ -142,37 +156,41 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e return nil } +func (a *API) handleOAuthCallback(w http.ResponseWriter, r *http.Request) (*OAuthProviderData, error) { + ctx := r.Context() + providerType := getExternalProviderType(ctx) + + var oAuthResponseData *OAuthProviderData + var err error + switch providerType { + case "twitter": + // future OAuth1.0 providers will use this method + oAuthResponseData, err = a.oAuth1Callback(ctx, r, providerType) + default: + oAuthResponseData, err = a.oAuthCallback(ctx, r, providerType) + } + if err != nil { + return nil, err + } + return oAuthResponseData, nil +} + func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) config := a.config - providerType := getExternalProviderType(ctx) - var userData *provider.UserProvidedData - var providerAccessToken string - var providerRefreshToken string var grantParams models.GrantParams - var err error - grantParams.FillGrantParams(r) - if providerType == "twitter" { - // future OAuth1.0 providers will use this method - oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType) - if err != nil { - return err - } - userData = oAuthResponseData.userData - providerAccessToken = oAuthResponseData.token - } else { - oAuthResponseData, err := a.oAuthCallback(ctx, r, providerType) - if err != nil { - return err - } - userData = oAuthResponseData.userData - providerAccessToken = oAuthResponseData.token - providerRefreshToken = oAuthResponseData.refreshToken + providerType := getExternalProviderType(ctx) + data, err := a.handleOAuthCallback(w, r) + if err != nil { + return err } + userData := data.userData + providerAccessToken := data.token + providerRefreshToken := data.refreshToken var flowState *models.FlowState // if there's a non-empty FlowStateID we perform PKCE Flow @@ -187,8 +205,11 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re var token *AccessTokenResponse err = db.Transaction(func(tx *storage.Connection) error { var terr error - inviteToken := getInviteToken(ctx) - if inviteToken != "" { + if targetUser := getTargetUser(ctx); targetUser != nil { + if user, terr = a.linkIdentityToUser(ctx, tx, userData, providerType); terr != nil { + return terr + } + } else if inviteToken := getInviteToken(ctx); inviteToken != "" { if user, terr = a.processInvite(r, ctx, tx, userData, inviteToken, providerType); terr != nil { return terr } @@ -471,6 +492,20 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont if claims.FlowStateID != "" { ctx = withFlowStateID(ctx, claims.FlowStateID) } + if claims.LinkingTargetID != "" { + linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) + if err != nil { + return nil, badRequestError("invalid target user id") + } + u, err := models.FindUserByID(a.db, linkingTargetUserID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError("Linking target user not found") + } + return nil, internalServerError("Database error loading user").WithInternalError(err) + } + ctx = withTargetUser(ctx, u) + } ctx = withExternalProviderType(ctx, claims.Provider) return withSignature(ctx, state), nil } diff --git a/internal/api/identity.go b/internal/api/identity.go index 769dec752b..247a17b2d1 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -1,10 +1,13 @@ package api import ( + "context" "net/http" + "github.com/fatih/structs" "github.com/go-chi/chi" "github.com/gofrs/uuid" + "github.com/supabase/gotrue/internal/api/provider" "github.com/supabase/gotrue/internal/models" "github.com/supabase/gotrue/internal/storage" ) @@ -65,3 +68,36 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { return sendJSON(w, http.StatusOK, map[string]interface{}{}) } + +func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error { + rurl, err := a.GetExternalProviderRedirectURL(w, r) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "url": rurl, + }) +} + +func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, userData *provider.UserProvidedData, providerType string) (*models.User, error) { + targetUser := getTargetUser(ctx) + identity, terr := models.FindIdentityByIdAndProvider(tx, userData.Metadata.Subject, providerType) + if terr != nil { + if !models.IsNotFoundError(terr) { + return nil, internalServerError("Database error finding identity for linking").WithInternalError(terr) + } + } + if identity != nil { + if identity.UserID == uuid.Nil { + return nil, badRequestError("Identity is already linked") + } + return nil, badRequestError("Identity is already linked to another user") + } + if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { + return nil, terr + } + if terr := targetUser.UpdateAppMetaDataProviders(tx); terr != nil { + return nil, terr + } + return targetUser, nil +} From bd5454f21fcf0b6d37d7e8dd39e9b7b0321970e3 Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Wed, 22 Nov 2023 09:00:58 -0500 Subject: [PATCH 05/11] Update internal/api/api.go Co-authored-by: Stojan Dimitrovski <sdimitrovski@gmail.com> --- internal/api/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/api/api.go b/internal/api/api.go index 99ce911803..af06d0cd1d 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -153,8 +153,8 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.Get("/", api.UserGet) r.With(sharedLimiter).Put("/", api.UserUpdate) - r.Delete("/identities/{identity_id}", api.DeleteIdentity) r.Get("/identities/authorize", api.LinkIdentity) + r.Delete("/identities/{identity_id}", api.DeleteIdentity) }) r.With(api.requireAuthentication).Route("/factors", func(r *router) { From f5f5de3628e551857b74dc58f58d542f1c9fb6f5 Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Wed, 22 Nov 2023 12:02:56 -0500 Subject: [PATCH 06/11] add skip_http_redirect option --- internal/api/identity.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/internal/api/identity.go b/internal/api/identity.go index 247a17b2d1..0b76e3a9c3 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -3,6 +3,7 @@ package api import ( "context" "net/http" + "strconv" "github.com/fatih/structs" "github.com/go-chi/chi" @@ -74,9 +75,17 @@ func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - return sendJSON(w, http.StatusOK, map[string]interface{}{ - "url": rurl, - }) + skipHTTPRedirect, err := strconv.ParseBool(r.URL.Query().Get("skip_http_redirect")) + if err != nil { + skipHTTPRedirect = false + } + if skipHTTPRedirect { + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "url": rurl, + }) + } + http.Redirect(w, r, rurl, http.StatusFound) + return nil } func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, userData *provider.UserProvidedData, providerType string) (*models.User, error) { From ada93d7d44c8c6650159ebf2331915bb7186b31f Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Wed, 22 Nov 2023 12:18:16 -0500 Subject: [PATCH 07/11] refactor linking logic --- internal/api/external.go | 9 ++++----- internal/api/identity.go | 4 +++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/internal/api/external.go b/internal/api/external.go index c976714e29..577ee78c9b 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -34,7 +34,7 @@ type ExternalProviderClaims struct { // ExternalProviderRedirect redirects the request to the oauth provider func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error { - rurl, err := a.GetExternalProviderRedirectURL(w, r) + rurl, err := a.GetExternalProviderRedirectURL(w, r, nil) if err != nil { return err } @@ -43,7 +43,7 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e } // GetExternalProviderRedirectURL returns the URL to start the oauth flow with the corresponding oauth provider -func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request) (string, error) { +func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request, linkingTargetUser *models.User) (string, error) { ctx := r.Context() db := a.db.WithContext(ctx) config := a.config @@ -108,10 +108,9 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ FlowStateID: flowStateID, } - if strings.Contains(r.URL.Path, "identities") { + if linkingTargetUser != nil { // this means that the user is performing manual linking - user := getUser(ctx) - claims.LinkingTargetID = user.ID.String() + claims.LinkingTargetID = linkingTargetUser.ID.String() } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) diff --git a/internal/api/identity.go b/internal/api/identity.go index 0b76e3a9c3..c132d2f0dd 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -71,7 +71,9 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error { - rurl, err := a.GetExternalProviderRedirectURL(w, r) + ctx := r.Context() + user := getUser(ctx) + rurl, err := a.GetExternalProviderRedirectURL(w, r, user) if err != nil { return err } From 3b0b550552e0ec74da4a37f8461f7dd248af5224 Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Tue, 28 Nov 2023 00:05:09 -0800 Subject: [PATCH 08/11] add test for linkIdentityToUser --- internal/api/identity.go | 2 +- internal/api/identity_test.go | 77 +++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 internal/api/identity_test.go diff --git a/internal/api/identity.go b/internal/api/identity.go index c132d2f0dd..58727b3753 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -99,7 +99,7 @@ func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, us } } if identity != nil { - if identity.UserID == uuid.Nil { + if identity.UserID == targetUser.ID { return nil, badRequestError("Identity is already linked") } return nil, badRequestError("Identity is already linked to another user") diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go new file mode 100644 index 0000000000..a14b770a70 --- /dev/null +++ b/internal/api/identity_test.go @@ -0,0 +1,77 @@ +package api + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/gotrue/internal/api/provider" + "github.com/supabase/gotrue/internal/conf" + "github.com/supabase/gotrue/internal/models" +) + +type IdentityTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestIdentity(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + ts := &IdentityTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + suite.Run(t, ts) +} + +func (ts *IdentityTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + + // Create identity + i, err := models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": u.GetEmail(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i)) +} + +func (ts *IdentityTestSuite) TestLinkIdentityToUser() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + ctx := withTargetUser(context.Background(), u) + + // link a valid identity + testValidUserData := &provider.UserProvidedData{ + Metadata: &provider.Claims{ + Subject: "test_subject", + }, + } + u, err = ts.API.linkIdentityToUser(ctx, ts.API.db, testValidUserData, "test") + require.NoError(ts.T(), err) + + // load associated identities for the user + ts.API.db.Load(u, "Identities") + require.Len(ts.T(), u.Identities, 2) + require.Equal(ts.T(), u.AppMetaData["provider"], "email") + require.Equal(ts.T(), u.AppMetaData["providers"], []string{"email", "test"}) + + // link an already existing identity + testExistingUserData := &provider.UserProvidedData{ + Metadata: &provider.Claims{ + Subject: u.ID.String(), + }, + } + u, err = ts.API.linkIdentityToUser(ctx, ts.API.db, testExistingUserData, "email") + require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked")) + require.Nil(ts.T(), u) +} From 92ca1077bd370311481584032a8d801e60b12719 Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Wed, 29 Nov 2023 06:32:42 -0800 Subject: [PATCH 09/11] Update internal/api/identity.go Co-authored-by: Stojan Dimitrovski <sdimitrovski@gmail.com> --- internal/api/identity.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/internal/api/identity.go b/internal/api/identity.go index 58727b3753..28670f49bc 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -77,10 +77,7 @@ func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - skipHTTPRedirect, err := strconv.ParseBool(r.URL.Query().Get("skip_http_redirect")) - if err != nil { - skipHTTPRedirect = false - } + skipHTTPRedirect := r.URL.Query().Get("skip_http_redirect") == "true" if skipHTTPRedirect { return sendJSON(w, http.StatusOK, map[string]interface{}{ "url": rurl, From c3164f924d93d12d520408fc65d4347a14d9122f Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Wed, 29 Nov 2023 11:52:49 -0800 Subject: [PATCH 10/11] update error message & remove unused import --- internal/api/identity.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/api/identity.go b/internal/api/identity.go index 28670f49bc..b3b69faef8 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -3,7 +3,6 @@ package api import ( "context" "net/http" - "strconv" "github.com/fatih/structs" "github.com/go-chi/chi" @@ -33,7 +32,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { user := getUser(ctx) if len(user.Identities) <= 1 { - return badRequestError("Cannot unlink identity from user. User must have at least 1 identity after unlinking") + return badRequestError("User must have at least 1 identity after unlinking") } var identityToBeDeleted *models.Identity for i := range user.Identities { From dc82607e1e36e341c05d32c32ef66f63e9b4b5ad Mon Sep 17 00:00:00 2001 From: Kang Ming <kang.ming1996@gmail.com> Date: Wed, 29 Nov 2023 12:02:08 -0800 Subject: [PATCH 11/11] fix: add config for manual linking endpoints --- internal/api/api.go | 7 +++++-- internal/api/middleware.go | 8 ++++++++ internal/conf/configuration.go | 1 + 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index af06d0cd1d..e7a7c8200d 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -153,8 +153,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.Get("/", api.UserGet) r.With(sharedLimiter).Put("/", api.UserUpdate) - r.Get("/identities/authorize", api.LinkIdentity) - r.Delete("/identities/{identity_id}", api.DeleteIdentity) + r.Route("/identities", func(r *router) { + r.Use(api.requireManualLinkingEnabled) + r.Get("/authorize", api.LinkIdentity) + r.Delete("/{identity_id}", api.DeleteIdentity) + }) }) r.With(api.requireAuthentication).Route("/factors", func(r *router) { diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 0ccb51a5a7..d529b89e03 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -233,6 +233,14 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont return ctx, nil } +func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + if !a.config.Security.ManualLinkingEnabled { + return nil, notFoundError("Manual linking is disabled") + } + return ctx, nil +} + func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index a3778a5c0c..6d7121889c 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -351,6 +351,7 @@ type SecurityConfiguration struct { RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"` RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"` UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"` + ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"` } func (c *SecurityConfiguration) Validate() error {